xiaoanyu123 commited on
Commit
44fba03
·
verified ·
1 Parent(s): 8247620

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. pythonProject/.venv/Lib/site-packages/diffusers/commands/__init__.py +27 -0
  2. pythonProject/.venv/Lib/site-packages/diffusers/commands/__pycache__/__init__.cpython-310.pyc +0 -0
  3. pythonProject/.venv/Lib/site-packages/diffusers/commands/__pycache__/custom_blocks.cpython-310.pyc +0 -0
  4. pythonProject/.venv/Lib/site-packages/diffusers/commands/__pycache__/diffusers_cli.cpython-310.pyc +0 -0
  5. pythonProject/.venv/Lib/site-packages/diffusers/commands/__pycache__/env.cpython-310.pyc +0 -0
  6. pythonProject/.venv/Lib/site-packages/diffusers/commands/__pycache__/fp16_safetensors.cpython-310.pyc +0 -0
  7. pythonProject/.venv/Lib/site-packages/diffusers/commands/diffusers_cli.py +45 -0
  8. pythonProject/.venv/Lib/site-packages/diffusers/commands/env.py +180 -0
  9. pythonProject/.venv/Lib/site-packages/diffusers/commands/fp16_safetensors.py +132 -0
  10. pythonProject/.venv/Lib/site-packages/diffusers/experimental/__init__.py +1 -0
  11. pythonProject/.venv/Lib/site-packages/diffusers/experimental/rl/__init__.py +1 -0
  12. pythonProject/.venv/Lib/site-packages/diffusers/experimental/rl/value_guided_sampling.py +153 -0
  13. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_deis_multistep.py +893 -0
  14. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_dpmsolver_multistep.py +1180 -0
  15. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +645 -0
  16. pythonProject/.venv/Lib/site-packages/functorch/compile/__init__.py +30 -0
  17. pythonProject/.venv/Lib/site-packages/functorch/compile/__pycache__/__init__.cpython-310.pyc +0 -0
  18. pythonProject/.venv/Lib/site-packages/functorch/dim/batch_tensor.py +25 -0
  19. pythonProject/.venv/Lib/site-packages/functorch/dim/delayed_mul_tensor.py +77 -0
  20. pythonProject/.venv/Lib/site-packages/functorch/dim/dim.py +121 -0
pythonProject/.venv/Lib/site-packages/diffusers/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from argparse import ArgumentParser
17
+
18
+
19
class BaseDiffusersCLICommand(ABC):
    """Abstract interface implemented by every `diffusers-cli` subcommand.

    A concrete command registers its own argparse subparser through
    `register_subcommand` and performs its work in `run`.
    """

    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach this command's subparser and arguments to `parser`."""
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        """Execute the command."""
        raise NotImplementedError()
pythonProject/.venv/Lib/site-packages/diffusers/commands/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (805 Bytes). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/commands/__pycache__/custom_blocks.cpython-310.pyc ADDED
Binary file (4.43 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/commands/__pycache__/diffusers_cli.cpython-310.pyc ADDED
Binary file (920 Bytes). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/commands/__pycache__/env.cpython-310.pyc ADDED
Binary file (4.3 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/commands/__pycache__/fp16_safetensors.cpython-310.pyc ADDED
Binary file (4.22 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/commands/diffusers_cli.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from argparse import ArgumentParser
17
+
18
+ from .custom_blocks import CustomBlocksCommand
19
+ from .env import EnvironmentCommand
20
+ from .fp16_safetensors import FP16SafetensorsCommand
21
+
22
+
23
def main():
    """Entry point for the `diffusers-cli` executable.

    Builds the top-level argument parser, lets each command class register
    its own subparser, then dispatches to the selected command's `run()`.
    """
    parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
    commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")

    # Each command attaches its subparser and sets a `func` factory default.
    EnvironmentCommand.register_subcommand(commands_parser)
    FP16SafetensorsCommand.register_subcommand(commands_parser)
    CustomBlocksCommand.register_subcommand(commands_parser)

    args = parser.parse_args()

    # No subcommand given: argparse leaves `func` unset, so print help and
    # exit with a failure status. `raise SystemExit(1)` is used instead of the
    # original `exit(1)`: `exit` is injected by the `site` module and is not
    # guaranteed to exist (e.g. under `python -S` or embedded interpreters).
    if not hasattr(args, "func"):
        parser.print_help()
        raise SystemExit(1)

    # `func` is a factory that builds a BaseDiffusersCLICommand from the
    # parsed arguments; run it.
    service = args.func(args)
    service.run()


if __name__ == "__main__":
    main()
pythonProject/.venv/Lib/site-packages/diffusers/commands/env.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import platform
16
+ import subprocess
17
+ from argparse import ArgumentParser
18
+
19
+ import huggingface_hub
20
+
21
+ from .. import __version__ as version
22
+ from ..utils import (
23
+ is_accelerate_available,
24
+ is_bitsandbytes_available,
25
+ is_flax_available,
26
+ is_google_colab,
27
+ is_peft_available,
28
+ is_safetensors_available,
29
+ is_torch_available,
30
+ is_transformers_available,
31
+ is_xformers_available,
32
+ )
33
+ from . import BaseDiffusersCLICommand
34
+
35
+
36
def info_command_factory(args):
    """argparse `func` default for the `env` subcommand.

    The parsed arguments are ignored: `env` takes no options.
    """
    return EnvironmentCommand()
38
+
39
+
40
class EnvironmentCommand(BaseDiffusersCLICommand):
    """`diffusers-cli env`: collect and print version/platform information
    in the format expected by GitHub bug reports."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser) -> None:
        # The `env` subcommand takes no arguments; the factory ignores them.
        download_parser = parser.add_parser("env")
        download_parser.set_defaults(func=info_command_factory)

    def run(self) -> dict:
        """Gather environment details, print the report, and return it as a
        dict (keys are human-readable labels)."""
        hub_version = huggingface_hub.__version__

        # Every optional dependency is imported lazily and only after its
        # availability check passes, so a missing package never breaks `env`;
        # the corresponding field simply stays "not installed".
        safetensors_version = "not installed"
        if is_safetensors_available():
            import safetensors

            safetensors_version = safetensors.__version__

        pt_version = "not installed"
        pt_cuda_available = "NA"
        if is_torch_available():
            import torch

            pt_version = torch.__version__
            pt_cuda_available = torch.cuda.is_available()

        flax_version = "not installed"
        jax_version = "not installed"
        jaxlib_version = "not installed"
        jax_backend = "NA"
        if is_flax_available():
            import flax
            import jax
            import jaxlib

            flax_version = flax.__version__
            jax_version = jax.__version__
            jaxlib_version = jaxlib.__version__
            # NOTE(review): `jax.lib.xla_bridge` is a private path that newer
            # jax releases deprecate — confirm against the pinned jax version.
            jax_backend = jax.lib.xla_bridge.get_backend().platform

        transformers_version = "not installed"
        if is_transformers_available():
            import transformers

            transformers_version = transformers.__version__

        accelerate_version = "not installed"
        if is_accelerate_available():
            import accelerate

            accelerate_version = accelerate.__version__

        peft_version = "not installed"
        if is_peft_available():
            import peft

            peft_version = peft.__version__

        bitsandbytes_version = "not installed"
        if is_bitsandbytes_available():
            import bitsandbytes

            bitsandbytes_version = bitsandbytes.__version__

        xformers_version = "not installed"
        if is_xformers_available():
            import xformers

            xformers_version = xformers.__version__

        platform_info = platform.platform()

        is_google_colab_str = "Yes" if is_google_colab() else "No"

        # Best-effort accelerator detection via external tools; a missing
        # binary (FileNotFoundError) leaves the field as "NA".
        accelerator = "NA"
        if platform.system() in {"Linux", "Windows"}:
            try:
                sp = subprocess.Popen(
                    ["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader"],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                )
                out_str, _ = sp.communicate()
                out_str = out_str.decode("utf-8")

                if len(out_str) > 0:
                    accelerator = out_str.strip()
            except FileNotFoundError:
                pass
        elif platform.system() == "Darwin":  # Mac OS
            try:
                sp = subprocess.Popen(
                    ["system_profiler", "SPDisplaysDataType"],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                )
                out_str, _ = sp.communicate()
                out_str = out_str.decode("utf-8")

                # Scrape the chipset name and (if reported) total VRAM out of
                # system_profiler's human-readable output.
                start = out_str.find("Chipset Model:")
                if start != -1:
                    start += len("Chipset Model:")
                    end = out_str.find("\n", start)
                    accelerator = out_str[start:end].strip()

                start = out_str.find("VRAM (Total):")
                if start != -1:
                    start += len("VRAM (Total):")
                    end = out_str.find("\n", start)
                    # NOTE(review): if "Chipset Model:" was absent, this
                    # appends to the literal "NA" — presumably acceptable.
                    accelerator += " VRAM: " + out_str[start:end].strip()
            except FileNotFoundError:
                pass
        else:
            print("It seems you are running an unusual OS. Could you fill in the accelerator manually?")

        # Assemble the report; the last two entries are intentionally left as
        # placeholders for the user to fill in.
        info = {
            "🤗 Diffusers version": version,
            "Platform": platform_info,
            "Running on Google Colab?": is_google_colab_str,
            "Python version": platform.python_version(),
            "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
            "Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
            "Jax version": jax_version,
            "JaxLib version": jaxlib_version,
            "Huggingface_hub version": hub_version,
            "Transformers version": transformers_version,
            "Accelerate version": accelerate_version,
            "PEFT version": peft_version,
            "Bitsandbytes version": bitsandbytes_version,
            "Safetensors version": safetensors_version,
            "xFormers version": xformers_version,
            "Accelerator": accelerator,
            "Using GPU in script?": "<fill in>",
            "Using distributed or parallel set-up in script?": "<fill in>",
        }

        print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
        print(self.format_dict(info))

        return info

    @staticmethod
    def format_dict(d: dict) -> str:
        # Render as a Markdown bullet list: one "- key: value" line per entry.
        return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
pythonProject/.venv/Lib/site-packages/diffusers/commands/fp16_safetensors.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Usage example:
17
+ diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors
18
+ """
19
+
20
+ import glob
21
+ import json
22
+ import warnings
23
+ from argparse import ArgumentParser, Namespace
24
+ from importlib import import_module
25
+
26
+ import huggingface_hub
27
+ import torch
28
+ from huggingface_hub import hf_hub_download
29
+ from packaging import version
30
+
31
+ from ..utils import logging
32
+ from . import BaseDiffusersCLICommand
33
+
34
+
35
def conversion_command_factory(args: Namespace):
    """argparse `func` default: build the conversion command from parsed args."""
    if args.use_auth_token:
        # Deprecated flag, kept only so old invocations get a warning; the
        # token is now picked up automatically from the local HF login.
        warnings.warn(
            "The `--use_auth_token` flag is deprecated and will be removed in a future version. Authentication is now"
            " handled automatically if user is logged in."
        )
    return FP16SafetensorsCommand(args.ckpt_id, args.fp16, args.use_safetensors)
42
+
43
+
44
class FP16SafetensorsCommand(BaseDiffusersCLICommand):
    """`diffusers-cli fp16_safetensors`: download a pipeline checkpoint,
    re-serialize it (optionally as FP16 and/or safetensors), and open a Hub
    pull request with the converted files.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach the `fp16_safetensors` subparser and its options."""
        conversion_parser = parser.add_parser("fp16_safetensors")
        conversion_parser.add_argument(
            "--ckpt_id",
            type=str,
            help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.",
        )
        conversion_parser.add_argument(
            "--fp16", action="store_true", help="If serializing the variables in FP16 precision."
        )
        conversion_parser.add_argument(
            "--use_safetensors", action="store_true", help="If serializing in the safetensors format."
        )
        conversion_parser.add_argument(
            "--use_auth_token",
            action="store_true",
            help="When working with checkpoints having private visibility. When used `hf auth login` needs to be run beforehand.",
        )
        conversion_parser.set_defaults(func=conversion_command_factory)

    def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool):
        """Validate options and prepare a local working directory.

        Raises:
            NotImplementedError: if neither `fp16` nor `use_safetensors` is
                set — the command would have nothing to convert.
        """
        self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
        self.ckpt_id = ckpt_id
        # NOTE(review): assumes a POSIX-style /tmp; not portable to Windows.
        self.local_ckpt_dir = f"/tmp/{ckpt_id}"
        self.fp16 = fp16

        self.use_safetensors = use_safetensors

        if not self.use_safetensors and not self.fp16:
            raise NotImplementedError(
                "When `use_safetensors` and `fp16` both are False, then this command is of no use."
            )

    def run(self):
        """Convert the checkpoint and open a PR; returns None."""
        if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
            raise ImportError(
                "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
                " installation."
            )
        # Imported only after the version gate above has passed.
        from huggingface_hub import create_commit
        from huggingface_hub._commit_api import CommitOperationAdd

        # The pipeline class name is read from the checkpoint's own
        # model_index.json and resolved dynamically from `diffusers`.
        model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json")
        with open(model_index, "r") as f:
            pipeline_class_name = json.load(f)["_class_name"]
        pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
        self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")

        # Load the appropriate pipeline. We could have used `DiffusionPipeline`
        # here, but just to avoid any rough edge cases.
        pipeline = pipeline_class.from_pretrained(
            self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32
        )
        pipeline.save_pretrained(
            self.local_ckpt_dir,
            # Idiom fix: the flag is already a bool, no ternary needed.
            safe_serialization=self.use_safetensors,
            variant="fp16" if self.fp16 else None,
        )
        self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.")

        # Fetch all the paths. The constructor guarantees at least one of
        # `fp16`/`use_safetensors` is set, so `modified_paths` is always bound.
        if self.fp16:
            modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.fp16.*")
        elif self.use_safetensors:
            modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.safetensors")

        # Prepare for the PR.
        commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}."
        operations = []
        for path in modified_paths:
            # NOTE(review): dropping the first 4 path components assumes the
            # layout `/tmp/{org}/{repo}/...` — verify for nested ckpt_ids.
            operations.append(CommitOperationAdd(path_in_repo="/".join(path.split("/")[4:]), path_or_fileobj=path))

        # Open the PR.
        commit_description = (
            "Variables converted by the [`diffusers`' `fp16_safetensors`"
            " CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)."
        )
        hub_pr_url = create_commit(
            repo_id=self.ckpt_id,
            operations=operations,
            commit_message=commit_message,
            commit_description=commit_description,
            repo_type="model",
            create_pr=True,
        ).pr_url
        self.logger.info(f"PR created here: {hub_pr_url}.")
pythonProject/.venv/Lib/site-packages/diffusers/experimental/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .rl import ValueGuidedRLPipeline
pythonProject/.venv/Lib/site-packages/diffusers/experimental/rl/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .value_guided_sampling import ValueGuidedRLPipeline
pythonProject/.venv/Lib/site-packages/diffusers/experimental/rl/value_guided_sampling.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+ import tqdm
18
+
19
+ from ...models.unets.unet_1d import UNet1DModel
20
+ from ...pipelines import DiffusionPipeline
21
+ from ...utils.dummy_pt_objects import DDPMScheduler
22
+ from ...utils.torch_utils import randn_tensor
23
+
24
+
25
class ValueGuidedRLPipeline(DiffusionPipeline):
    r"""
    Pipeline for value-guided sampling from a diffusion model trained to predict sequences of states.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        value_function ([`UNet1DModel`]):
            A specialized UNet for fine-tuning trajectories base on reward.
        unet ([`UNet1DModel`]):
            UNet architecture to denoise the encoded trajectories.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
            application is [`DDPMScheduler`].
        env ():
            An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
    """

    def __init__(
        self,
        value_function: UNet1DModel,
        unet: UNet1DModel,
        scheduler: DDPMScheduler,
        env,
    ):
        super().__init__()

        self.register_modules(value_function=value_function, unet=unet, scheduler=scheduler, env=env)

        # Per-key dataset statistics used to (de)normalize observations and
        # actions. Non-numeric dataset entries raise inside mean()/std() and
        # are deliberately skipped by the bare excepts.
        self.data = env.get_dataset()
        self.means = {}
        for key in self.data.keys():
            try:
                self.means[key] = self.data[key].mean()
            except:  # noqa: E722
                pass
        self.stds = {}
        for key in self.data.keys():
            try:
                self.stds[key] = self.data[key].std()
            except:  # noqa: E722
                pass
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

    def normalize(self, x_in, key):
        # Standardize using the dataset statistics stored for `key`.
        return (x_in - self.means[key]) / self.stds[key]

    def de_normalize(self, x_in, key):
        # Invert `normalize` for `key`.
        return x_in * self.stds[key] + self.means[key]

    def to_torch(self, x_in):
        # Recursively move dicts / tensors / array-likes onto the UNet device.
        if isinstance(x_in, dict):
            return {k: self.to_torch(v) for k, v in x_in.items()}
        elif torch.is_tensor(x_in):
            return x_in.to(self.unet.device)
        return torch.tensor(x_in, device=self.unet.device)

    def reset_x0(self, x_in, cond, act_dim):
        # Overwrite the state slice (columns past `act_dim`) at each
        # conditioned timestep `key` with the fixed conditioning values.
        for key, val in cond.items():
            x_in[:, key, act_dim:] = val.clone()
        return x_in

    def run_diffusion(self, x, conditions, n_guide_steps, scale):
        """Denoise trajectories while nudging them along the value function's
        gradient (classifier-style guidance). Returns (x, y) where y is the
        last batch of value predictions (None if no timesteps ran)."""
        batch_size = x.shape[0]
        y = None
        for i in tqdm.tqdm(self.scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()

                    # permute to match dimension for pre-trained models
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    # Scale the gradient by the model's posterior std at this
                    # timestep so guidance strength tracks the noise level.
                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad

                # Disable guidance on the final (nearly noise-free) steps.
                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)

            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)

            # TODO: verify deprecation of this kwarg
            x = self.scheduler.step(prev_x, i, x)["prev_sample"]

            # apply conditions to the trajectory (set the initial state)
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y

    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        """Plan trajectories from observation `obs` and return the first
        de-normalized action of the highest-value trajectory."""
        # normalize the observations and create batch dimension
        # NOTE(review): `.repeat(..., axis=0)` is numpy-style — `obs` is
        # presumably a numpy array here; confirm against callers.
        obs = self.normalize(obs, "observations")
        obs = obs[None].repeat(batch_size, axis=0)

        # Condition timestep 0 on the current (normalized) observation.
        conditions = {0: self.to_torch(obs)}
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

        # generate initial noise and apply our conditions (to make the trajectories start at current state)
        x1 = randn_tensor(shape, device=self.unet.device)
        x = self.reset_x0(x1, conditions, self.action_dim)
        x = self.to_torch(x)

        # run the diffusion process
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)

        # sort output trajectories by value
        sorted_idx = y.argsort(0, descending=True).squeeze()
        sorted_values = x[sorted_idx]
        actions = sorted_values[:, :, : self.action_dim]
        actions = actions.detach().cpu().numpy()
        denorm_actions = self.de_normalize(actions, key="actions")

        # select the action with the highest value
        # NOTE(review): if y were None, the argsort above would already have
        # raised — this fallback branch appears unreachable; verify upstream.
        if y is not None:
            selected_index = 0
        else:
            # if we didn't run value guiding, select a random action
            selected_index = np.random.randint(0, batch_size)

        denorm_actions = denorm_actions[selected_index, 0]
        return denorm_actions
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_deis_multistep.py ADDED
@@ -0,0 +1,893 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 FLAIR Lab and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: check https://huggingface.co/papers/2204.13902 and https://github.com/qsh-zh/deis for more info
16
+ # The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+ from ..configuration_utils import ConfigMixin, register_to_config
25
+ from ..utils import deprecate, is_scipy_available
26
+ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
27
+
28
+
29
+ if is_scipy_available():
30
+ import scipy.stats
31
+
32
+
33
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
34
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Since alpha_bar(t) is the cumulative product of (1 - beta) up to time t, each discrete beta is recovered from the
    ratio of consecutive alpha_bar values, clipped at `max_beta` to prevent singularities near t = 1.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """
    # Select the continuous alpha_bar(t) transform.
    if alpha_transform_type == "cosine":
        alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2  # noqa: E731
    elif alpha_transform_type == "exp":
        alpha_bar = lambda t: math.exp(t * -12.0)  # noqa: E731
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    steps = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar((step + 1) / steps) / alpha_bar(step / steps), max_beta)
        for step in range(steps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
76
+
77
+
78
+ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
79
+ """
80
+ `DEISMultistepScheduler` is a fast high order solver for diffusion ordinary differential equations (ODEs).
81
+
82
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
83
+ methods the library implements for all schedulers such as loading and saving.
84
+
85
+ Args:
86
+ num_train_timesteps (`int`, defaults to 1000):
87
+ The number of diffusion steps to train the model.
88
+ beta_start (`float`, defaults to 0.0001):
89
+ The starting `beta` value of inference.
90
+ beta_end (`float`, defaults to 0.02):
91
+ The final `beta` value.
92
+ beta_schedule (`str`, defaults to `"linear"`):
93
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
94
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
95
+ trained_betas (`np.ndarray`, *optional*):
96
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
97
+ solver_order (`int`, defaults to 2):
98
+ The DEIS order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
99
+ sampling, and `solver_order=3` for unconditional sampling.
100
+ prediction_type (`str`, defaults to `epsilon`):
101
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
102
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
103
+ Video](https://imagen.research.google/video/paper.pdf) paper).
104
+ thresholding (`bool`, defaults to `False`):
105
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
106
+ as Stable Diffusion.
107
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
108
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
109
+ sample_max_value (`float`, defaults to 1.0):
110
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
111
+ algorithm_type (`str`, defaults to `deis`):
112
+ The algorithm type for the solver.
113
+ lower_order_final (`bool`, defaults to `True`):
114
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps.
115
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
116
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
117
+ the sigmas are determined according to a sequence of noise levels {σi}.
118
+ use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
119
+ Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
120
+ use_beta_sigmas (`bool`, *optional*, defaults to `False`):
121
+ Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
122
+ Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
123
+ timestep_spacing (`str`, defaults to `"linspace"`):
124
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
125
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
126
+ steps_offset (`int`, defaults to 0):
127
+ An offset added to the inference steps, as required by some model families.
128
+ """
129
+
130
    # Schedulers this one is config-compatible with (all Karras-style diffusion schedulers).
    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    # Each call to `step` consumes exactly one model output.
    order = 1
132
+
133
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        solver_order: int = 2,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "deis",
        solver_type: str = "logrho",
        lower_order_final: bool = True,
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        use_flow_sigmas: Optional[bool] = False,
        flow_shift: Optional[float] = 1.0,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        use_dynamic_shifting: bool = False,
        time_shift_type: str = "exponential",
    ):
        # Beta sigmas need scipy.stats.beta (see `_convert_to_beta`); fail early if scipy is missing.
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
        # The three custom sigma schedules are mutually exclusive.
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )
        # Build the training beta schedule.
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # settings for DEIS: DPM-Solver algorithm/solver names are silently remapped to their
        # DEIS equivalents (keeps configs written for DPMSolver schedulers loadable here).
        if algorithm_type not in ["deis"]:
            if algorithm_type in ["dpmsolver", "dpmsolver++"]:
                self.register_to_config(algorithm_type="deis")
            else:
                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")

        if solver_type not in ["logrho"]:
            if solver_type in ["midpoint", "heun", "bh1", "bh2"]:
                self.register_to_config(solver_type="logrho")
            else:
                raise NotImplementedError(f"solver type {solver_type} is not implemented for {self.__class__}")

        # setable values — refined later by `set_timesteps`.
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        # Rolling buffer of the last `solver_order` converted model outputs (multistep history).
        self.model_outputs = [None] * solver_order
        self.lower_order_nums = 0
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
211
+
212
    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        # `None` until `step` (via `_init_step_index`) or `set_timesteps` initializes it.
        return self._step_index
218
+
219
    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        # `None` unless the pipeline explicitly called `set_begin_index`.
        return self._begin_index
225
+
226
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
227
+ def set_begin_index(self, begin_index: int = 0):
228
+ """
229
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
230
+
231
+ Args:
232
+ begin_index (`int`):
233
+ The begin index for the scheduler.
234
+ """
235
+ self._begin_index = begin_index
236
+
237
    def set_timesteps(
        self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            mu (`float`, *optional*):
                Dynamic-shifting parameter; requires `use_dynamic_shifting=True` and
                `time_shift_type="exponential"`. When given, `flow_shift` is set to `exp(mu)`.
        """
        if mu is not None:
            # NOTE(review): mutates the stored config in place, so the shift persists across later calls.
            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
            self.config.flow_shift = np.exp(mu)
        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
                .round()[::-1][:-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # Sigmas of the full training schedule; possibly resampled by one of the branches below.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        if self.config.use_karras_sigmas:
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            # Recover the (rounded) timesteps corresponding to the resampled sigmas.
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
        elif self.config.use_exponential_sigmas:
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
        elif self.config.use_beta_sigmas:
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
        elif self.config.use_flow_sigmas:
            # Flow-matching schedule shaped by the (possibly dynamically shifted) flow_shift.
            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
            sigmas = 1.0 - alphas
            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
            timesteps = (sigmas * self.config.num_train_timesteps).copy()
            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
        else:
            # Default path: interpolate the training sigmas at the chosen integer timesteps.
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

        self.num_inference_steps = len(timesteps)

        # Reset the multistep history — outputs from a previous schedule are invalid here.
        self.model_outputs = [
            None,
        ] * self.config.solver_order
        self.lower_order_nums = 0

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
319
+
320
    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://huggingface.co/papers/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        # Restore the original shape and dtype before returning.
        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample
353
+
354
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        """Invert the sigma schedule: map `sigma` to a fractional timestep by piecewise log-linear interpolation."""
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t
377
+
378
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
379
+ def _sigma_to_alpha_sigma_t(self, sigma):
380
+ if self.config.use_flow_sigmas:
381
+ alpha_t = 1 - sigma
382
+ sigma_t = sigma
383
+ else:
384
+ alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
385
+ sigma_t = sigma * alpha_t
386
+
387
+ return alpha_t, sigma_t
388
+
389
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
390
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
391
+ """Constructs the noise schedule of Karras et al. (2022)."""
392
+
393
+ # Hack to make sure that other schedulers which copy this function don't break
394
+ # TODO: Add this logic to the other schedulers
395
+ if hasattr(self.config, "sigma_min"):
396
+ sigma_min = self.config.sigma_min
397
+ else:
398
+ sigma_min = None
399
+
400
+ if hasattr(self.config, "sigma_max"):
401
+ sigma_max = self.config.sigma_max
402
+ else:
403
+ sigma_max = None
404
+
405
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
406
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
407
+
408
+ rho = 7.0 # 7.0 is the value used in the paper
409
+ ramp = np.linspace(0, 1, num_inference_steps)
410
+ min_inv_rho = sigma_min ** (1 / rho)
411
+ max_inv_rho = sigma_max ** (1 / rho)
412
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
413
+ return sigmas
414
+
415
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
416
+ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
417
+ """Constructs an exponential noise schedule."""
418
+
419
+ # Hack to make sure that other schedulers which copy this function don't break
420
+ # TODO: Add this logic to the other schedulers
421
+ if hasattr(self.config, "sigma_min"):
422
+ sigma_min = self.config.sigma_min
423
+ else:
424
+ sigma_min = None
425
+
426
+ if hasattr(self.config, "sigma_max"):
427
+ sigma_max = self.config.sigma_max
428
+ else:
429
+ sigma_max = None
430
+
431
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
432
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
433
+
434
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
435
+ return sigmas
436
+
437
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
438
+ def _convert_to_beta(
439
+ self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
440
+ ) -> torch.Tensor:
441
+ """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
442
+
443
+ # Hack to make sure that other schedulers which copy this function don't break
444
+ # TODO: Add this logic to the other schedulers
445
+ if hasattr(self.config, "sigma_min"):
446
+ sigma_min = self.config.sigma_min
447
+ else:
448
+ sigma_min = None
449
+
450
+ if hasattr(self.config, "sigma_max"):
451
+ sigma_max = self.config.sigma_max
452
+ else:
453
+ sigma_max = None
454
+
455
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
456
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
457
+
458
+ sigmas = np.array(
459
+ [
460
+ sigma_min + (ppf * (sigma_max - sigma_min))
461
+ for ppf in [
462
+ scipy.stats.beta.ppf(timestep, alpha, beta)
463
+ for timestep in 1 - np.linspace(0, 1, num_inference_steps)
464
+ ]
465
+ ]
466
+ )
467
+ return sigmas
468
+
469
    def convert_model_output(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Convert the model output to the corresponding type the DEIS algorithm needs.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                Deprecated; has no effect (the internal `self.step_index` counter is used instead).
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The converted model output.
        """
        # Backward-compat argument handling: `timestep` may still arrive positionally or by keyword.
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma = self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        # First recover the predicted clean sample x0 from whichever parameterization the model uses.
        if self.config.prediction_type == "epsilon":
            x0_pred = (sample - sigma_t * model_output) / alpha_t
        elif self.config.prediction_type == "sample":
            x0_pred = model_output
        elif self.config.prediction_type == "v_prediction":
            x0_pred = alpha_t * sample - sigma_t * model_output
        elif self.config.prediction_type == "flow_prediction":
            # NOTE(review): here `sigma_t` is rebound to the raw schedule sigma, and the DEIS
            # conversion below divides by this rebound value — matches upstream; verify intentional.
            sigma_t = self.sigmas[self.step_index]
            x0_pred = sample - sigma_t * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                "`v_prediction`, or `flow_prediction` for the DEISMultistepScheduler."
            )

        if self.config.thresholding:
            x0_pred = self._threshold_sample(x0_pred)

        # DEIS integrates in epsilon space, so convert x0 back to a noise prediction.
        if self.config.algorithm_type == "deis":
            return (sample - alpha_t * x0_pred) / sigma_t
        else:
            raise NotImplementedError("only support log-rho multistep deis now")
528
+
529
    def deis_first_order_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the first-order DEIS (equivalent to DDIM).

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                Deprecated; has no effect (the internal `self.step_index` counter is used instead).
            prev_timestep (`int`):
                Deprecated; has no effect.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        # Backward-compat argument handling for the removed positional timestep arguments.
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # sigma at the target (next) step and at the current step.
        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
        # Half log-SNR at both steps; h is the integration interval in lambda-space.
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s = torch.log(alpha_s) - torch.log(sigma_s)

        h = lambda_t - lambda_s
        if self.config.algorithm_type == "deis":
            # Exponential-integrator step: x_t = (alpha_t/alpha_s) x_s - sigma_t (e^h - 1) eps.
            x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
        else:
            raise NotImplementedError("only support log-rho multistep deis now")
        return x_t
586
+
587
    def multistep_deis_second_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the second-order multistep DEIS.

        Args:
            model_output_list (`List[torch.Tensor]`):
                The direct outputs from learned diffusion model at current and latter timesteps.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        # Backward-compat argument handling for the removed positional timestep arguments.
        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # sigmas at the target step, current step, and one step back.
        sigma_t, sigma_s0, sigma_s1 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

        # m0 is the latest model output, m1 the one before it.
        m0, m1 = model_output_list[-1], model_output_list[-2]

        # rho = sigma/alpha, the integration variable of the log-rho DEIS formulation.
        rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1

        if self.config.algorithm_type == "deis":

            def ind_fn(t, b, c):
                # Integrate[(log(t) - log(c)) / (log(b) - log(c)), {t}]
                return t * (-np.log(c) + np.log(t) - 1) / (np.log(b) - np.log(c))

            # Lagrange-polynomial extrapolation coefficients for the two stored outputs.
            coef1 = ind_fn(rho_t, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s0, rho_s1)
            coef2 = ind_fn(rho_t, rho_s1, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s0)

            x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1)
            return x_t
        else:
            raise NotImplementedError("only support log-rho multistep deis now")
655
+
656
    def multistep_deis_third_order_update(
        self,
        model_output_list: List[torch.Tensor],
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the third-order multistep DEIS.

        Args:
            model_output_list (`List[torch.Tensor]`):
                The direct outputs from learned diffusion model at current and latter timesteps.
            sample (`torch.Tensor`):
                A current instance of a sample created by diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """

        # Backward-compat argument handling for the removed positional timestep arguments.
        timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
        prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 2:
                sample = args[2]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep_list is not None:
            deprecate(
                "timestep_list",
                "1.0.0",
                "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        # sigmas at the target step, current step, and the two steps back.
        sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
            self.sigmas[self.step_index - 2],
        )

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
        alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

        # Latest three stored model outputs, newest first.
        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

        # rho = sigma/alpha, the integration variable of the log-rho DEIS formulation.
        rho_t, rho_s0, rho_s1, rho_s2 = (
            sigma_t / alpha_t,
            sigma_s0 / alpha_s0,
            sigma_s1 / alpha_s1,
            sigma_s2 / alpha_s2,
        )

        if self.config.algorithm_type == "deis":

            def ind_fn(t, b, c, d):
                # Integrate[(log(t) - log(c))(log(t) - log(d)) / (log(b) - log(c))(log(b) - log(d)), {t}]
                numerator = t * (
                    np.log(c) * (np.log(d) - np.log(t) + 1)
                    - np.log(d) * np.log(t)
                    + np.log(d)
                    + np.log(t) ** 2
                    - 2 * np.log(t)
                    + 2
                )
                denominator = (np.log(b) - np.log(c)) * (np.log(b) - np.log(d))
                return numerator / denominator

            # Lagrange-polynomial extrapolation coefficients for the three stored outputs.
            coef1 = ind_fn(rho_t, rho_s0, rho_s1, rho_s2) - ind_fn(rho_s0, rho_s0, rho_s1, rho_s2)
            coef2 = ind_fn(rho_t, rho_s1, rho_s2, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s2, rho_s0)
            coef3 = ind_fn(rho_t, rho_s2, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s2, rho_s0, rho_s1)

            x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1 + coef3 * m2)

            return x_t
        else:
            raise NotImplementedError("only support log-rho multistep deis now")
743
+
744
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
745
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
746
+ if schedule_timesteps is None:
747
+ schedule_timesteps = self.timesteps
748
+
749
+ index_candidates = (schedule_timesteps == timestep).nonzero()
750
+
751
+ if len(index_candidates) == 0:
752
+ step_index = len(self.timesteps) - 1
753
+ # The sigma index that is taken for the **very** first `step`
754
+ # is always the second index (or the last index if there is only 1)
755
+ # This way we can ensure we don't accidentally skip a sigma in
756
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
757
+ elif len(index_candidates) > 1:
758
+ step_index = index_candidates[1].item()
759
+ else:
760
+ step_index = index_candidates[0].item()
761
+
762
+ return step_index
763
+
764
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
765
+ def _init_step_index(self, timestep):
766
+ """
767
+ Initialize the step_index counter for the scheduler.
768
+ """
769
+
770
+ if self.begin_index is None:
771
+ if isinstance(timestep, torch.Tensor):
772
+ timestep = timestep.to(self.timesteps.device)
773
+ self._step_index = self.index_for_timestep(timestep)
774
+ else:
775
+ self._step_index = self._begin_index
776
+
777
    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the multistep DEIS.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # Lazily initialize the step counter on the first call.
        if self.step_index is None:
            self._init_step_index(timestep)

        # Drop to lower-order solvers near the end of short schedules for numerical stability.
        lower_order_final = (
            (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
        )
        lower_order_second = (
            (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
        )

        model_output = self.convert_model_output(model_output, sample=sample)
        # Shift the multistep history left and append the newest converted output.
        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = model_output

        # Choose the highest solver order the accumulated history (and end-of-schedule rules) allow.
        if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
            prev_sample = self.deis_first_order_update(model_output, sample=sample)
        elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
            prev_sample = self.multistep_deis_second_order_update(self.model_outputs, sample=sample)
        else:
            prev_sample = self.multistep_deis_third_order_update(self.model_outputs, sample=sample)

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)
841
+
842
+ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
843
+ """
844
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
845
+ current timestep.
846
+
847
+ Args:
848
+ sample (`torch.Tensor`):
849
+ The input sample.
850
+
851
+ Returns:
852
+ `torch.Tensor`:
853
+ A scaled input sample.
854
+ """
855
+ return sample
856
+
857
    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Forward-diffuse `original_samples` to the noise levels given by `timesteps`:
        returns `alpha_t * original_samples + sigma_t * noise`.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        # Broadcast the per-sample sigma to the sample's rank.
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples
891
+
892
    def __len__(self):
        # The scheduler's "length" is the size of its training discretization.
        return self.config.num_train_timesteps
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_dpmsolver_multistep.py ADDED
@@ -0,0 +1,1180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
16
+
17
+ import math
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from ..configuration_utils import ConfigMixin, register_to_config
24
+ from ..utils import deprecate, is_scipy_available
25
+ from ..utils.torch_utils import randn_tensor
26
+ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
27
+
28
+
29
+ if is_scipy_available():
30
+ import scipy.stats
31
+
32
+
33
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    # Note: the docstring previously advertised an `np.ndarray` return; the function
    # has always returned a `torch.Tensor`, which is what callers (schedulers) expect.
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i = 1 - alpha_bar(t2)/alpha_bar(t1), clamped to max_beta for stability.
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
76
+
77
+
78
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so the terminal signal-to-noise ratio is exactly zero
    (Algorithm 1 of https://huggingface.co/papers/2305.08891).

    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # Work in sqrt(alpha_bar) space, where the rescaling is an affine map.
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Remember the endpoints before modifying the schedule.
    first = sqrt_alpha_bar[0].clone()
    last = sqrt_alpha_bar[-1].clone()

    # Shift so sqrt(alpha_bar) ends at zero, then rescale so it still starts at `first`.
    sqrt_alpha_bar -= last
    sqrt_alpha_bar *= first / (first - last)

    # Map back: alpha_bar -> per-step alphas (reverting the cumulative product) -> betas.
    alpha_bar = sqrt_alpha_bar**2
    step_alphas = alpha_bar[1:] / alpha_bar[:-1]
    step_alphas = torch.cat([alpha_bar[0:1], step_alphas])
    return 1 - step_alphas
113
+
114
+
115
+ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
116
+ """
117
+ `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
118
+
119
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
120
+ methods the library implements for all schedulers such as loading and saving.
121
+
122
+ Args:
123
+ num_train_timesteps (`int`, defaults to 1000):
124
+ The number of diffusion steps to train the model.
125
+ beta_start (`float`, defaults to 0.0001):
126
+ The starting `beta` value of inference.
127
+ beta_end (`float`, defaults to 0.02):
128
+ The final `beta` value.
129
+ beta_schedule (`str`, defaults to `"linear"`):
130
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
131
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
132
+ trained_betas (`np.ndarray`, *optional*):
133
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
134
+ solver_order (`int`, defaults to 2):
135
+ The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
136
+ sampling, and `solver_order=3` for unconditional sampling.
137
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
138
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
139
+ `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
140
+ Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
141
+ thresholding (`bool`, defaults to `False`):
142
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
143
+ as Stable Diffusion.
144
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
145
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
146
+ sample_max_value (`float`, defaults to 1.0):
147
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
148
+ `algorithm_type="dpmsolver++"`.
149
+ algorithm_type (`str`, defaults to `dpmsolver++`):
150
+ Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
151
+ `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
152
+ paper, and the `dpmsolver++` type implements the algorithms in the
153
+ [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
154
+ `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
155
+ solver_type (`str`, defaults to `midpoint`):
156
+ Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
157
+ sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
158
+ lower_order_final (`bool`, defaults to `True`):
159
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
160
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
161
+ euler_at_final (`bool`, defaults to `False`):
162
+ Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
163
+ richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
164
+ steps, but sometimes may result in blurring.
165
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
166
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
167
+ the sigmas are determined according to a sequence of noise levels {σi}.
168
+ use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
169
+ Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
170
+ use_beta_sigmas (`bool`, *optional*, defaults to `False`):
171
+ Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
172
+ Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
173
+ use_lu_lambdas (`bool`, *optional*, defaults to `False`):
174
+ Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
175
+ the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
176
+ `lambda(t)`.
177
+ use_flow_sigmas (`bool`, *optional*, defaults to `False`):
178
+ Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
179
+ flow_shift (`float`, *optional*, defaults to 1.0):
180
+ The shift value for the timestep schedule for flow matching.
181
+ final_sigmas_type (`str`, defaults to `"zero"`):
182
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
183
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
184
+ lambda_min_clipped (`float`, defaults to `-inf`):
185
+ Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
186
+ cosine (`squaredcos_cap_v2`) noise schedule.
187
+ variance_type (`str`, *optional*):
188
+ Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
189
+ contains the predicted Gaussian variance.
190
+ timestep_spacing (`str`, defaults to `"linspace"`):
191
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
192
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
193
+ steps_offset (`int`, defaults to 0):
194
+ An offset added to the inference steps, as required by some model families.
195
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
196
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
197
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
198
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
199
+ """
200
+
201
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
202
+ order = 1
203
+
204
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "midpoint",
        lower_order_final: bool = True,
        euler_at_final: bool = False,
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        use_lu_lambdas: Optional[bool] = False,
        use_flow_sigmas: Optional[bool] = False,
        flow_shift: Optional[float] = 1.0,
        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
        lambda_min_clipped: float = -float("inf"),
        variance_type: Optional[str] = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
        use_dynamic_shifting: bool = False,
        use_dynamic_shifting: bool = False,
        time_shift_type: str = "exponential",
    ):
        """Build the beta/sigma schedules and validate the solver configuration (see class docstring for args)."""
        # NOTE(review): `self.config` is readable in this body because `@register_to_config`
        # appears to record the constructor arguments before the body runs — confirm against
        # `ConfigMixin.register_to_config`.
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
        # The alternative sigma-schedule flavors are mutually exclusive.
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )
        if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
            deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
            deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)

        # Build the beta schedule; `trained_betas` takes precedence over any named schedule.
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr:
            # Close to 0 without being 0 so first sigma is not inf
            # FP16 smallest positive subnormal works well here
            self.alphas_cumprod[-1] = 2**-24

        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # settings for DPM-Solver
        if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
            if algorithm_type == "deis":
                # "deis" is accepted for backward compatibility and mapped to dpmsolver++.
                self.register_to_config(algorithm_type="dpmsolver++")
            else:
                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")

        if solver_type not in ["midpoint", "heun"]:
            if solver_type in ["logrho", "bh1", "bh2"]:
                # Legacy solver names are mapped to the closest supported solver.
                self.register_to_config(solver_type="midpoint")
            else:
                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

        if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
            raise ValueError(
                f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
            )

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        # Rolling buffer of previous model outputs used by the multistep solver.
        self.model_outputs = [None] * solver_order
        self.lower_order_nums = 0
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
306
+
307
+ @property
308
+ def step_index(self):
309
+ """
310
+ The index counter for current timestep. It will increase 1 after each scheduler step.
311
+ """
312
+ return self._step_index
313
+
314
+ @property
315
+ def begin_index(self):
316
+ """
317
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
318
+ """
319
+ return self._begin_index
320
+
321
+ def set_begin_index(self, begin_index: int = 0):
322
+ """
323
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
324
+
325
+ Args:
326
+ begin_index (`int`):
327
+ The begin index for the scheduler.
328
+ """
329
+ self._begin_index = begin_index
330
+
331
    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        mu: Optional[float] = None,
        timesteps: Optional[List[int]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            mu (`float`, *optional*):
                Dynamic-shifting parameter; requires `config.use_dynamic_shifting` with
                `config.time_shift_type == "exponential"`. When given, `config.flow_shift` is set to `exp(mu)`.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
                based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
                must be `None`, and `timestep_spacing` attribute will be ignored.
        """
        if mu is not None:
            # NOTE(review): validation via `assert` is stripped under `python -O`, and this
            # mutates the registered config in place — confirm this is intentional upstream.
            assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
            self.config.flow_shift = np.exp(mu)
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
        # Custom timesteps are incompatible with every derived sigma schedule below.
        if timesteps is not None and self.config.use_karras_sigmas:
            raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
        if timesteps is not None and self.config.use_lu_lambdas:
            raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
        if timesteps is not None and self.config.use_exponential_sigmas:
            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
        if timesteps is not None and self.config.use_beta_sigmas:
            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")

        if timesteps is not None:
            timesteps = np.array(timesteps).astype(np.int64)
        else:
            # Clipping the minimum of all lambda(t) for numerical stability.
            # This is critical for cosine (squaredcos_cap_v2) noise schedule.
            clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
            last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()

            # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
            if self.config.timestep_spacing == "linspace":
                timesteps = (
                    np.linspace(0, last_timestep - 1, num_inference_steps + 1)
                    .round()[::-1][:-1]
                    .copy()
                    .astype(np.int64)
                )
            elif self.config.timestep_spacing == "leading":
                step_ratio = last_timestep // (num_inference_steps + 1)
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = (
                    (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
                )
                timesteps += self.config.steps_offset
            elif self.config.timestep_spacing == "trailing":
                step_ratio = self.config.num_train_timesteps / num_inference_steps
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
                timesteps -= 1
            else:
                raise ValueError(
                    f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
                )

        # Full-resolution training sigmas; the selected schedule flavor below subsamples
        # or replaces these, and recomputes the matching (possibly fractional) timesteps.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)

        if self.config.use_karras_sigmas:
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
            if self.config.beta_schedule != "squaredcos_cap_v2":
                timesteps = timesteps.round()
        elif self.config.use_lu_lambdas:
            lambdas = np.flip(log_sigmas.copy())
            lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
            sigmas = np.exp(lambdas)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
            if self.config.beta_schedule != "squaredcos_cap_v2":
                timesteps = timesteps.round()
        elif self.config.use_exponential_sigmas:
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_beta_sigmas:
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_flow_sigmas:
            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
            sigmas = 1.0 - alphas
            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
            timesteps = (sigmas * self.config.num_train_timesteps).copy()
        else:
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if self.config.final_sigmas_type == "sigma_min":
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
        elif self.config.final_sigmas_type == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
            )

        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

        self.num_inference_steps = len(timesteps)

        # Reset multistep history: previous model outputs are no longer valid.
        self.model_outputs = [
            None,
        ] * self.config.solver_order
        self.lower_order_nums = 0

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
459
+
460
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
461
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
462
+ """
463
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
464
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
465
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
466
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
467
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
468
+
469
+ https://huggingface.co/papers/2205.11487
470
+ """
471
+ dtype = sample.dtype
472
+ batch_size, channels, *remaining_dims = sample.shape
473
+
474
+ if dtype not in (torch.float32, torch.float64):
475
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
476
+
477
+ # Flatten sample for doing quantile calculation along each image
478
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
479
+
480
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
481
+
482
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
483
+ s = torch.clamp(
484
+ s, min=1, max=self.config.sample_max_value
485
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
486
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
487
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
488
+
489
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
490
+ sample = sample.to(dtype)
491
+
492
+ return sample
493
+
494
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
495
+ def _sigma_to_t(self, sigma, log_sigmas):
496
+ # get log sigma
497
+ log_sigma = np.log(np.maximum(sigma, 1e-10))
498
+
499
+ # get distribution
500
+ dists = log_sigma - log_sigmas[:, np.newaxis]
501
+
502
+ # get sigmas range
503
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
504
+ high_idx = low_idx + 1
505
+
506
+ low = log_sigmas[low_idx]
507
+ high = log_sigmas[high_idx]
508
+
509
+ # interpolate sigmas
510
+ w = (low - log_sigma) / (low - high)
511
+ w = np.clip(w, 0, 1)
512
+
513
+ # transform interpolation to time range
514
+ t = (1 - w) * low_idx + w * high_idx
515
+ t = t.reshape(sigma.shape)
516
+ return t
517
+
518
+ def _sigma_to_alpha_sigma_t(self, sigma):
519
+ if self.config.use_flow_sigmas:
520
+ alpha_t = 1 - sigma
521
+ sigma_t = sigma
522
+ else:
523
+ alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
524
+ sigma_t = sigma * alpha_t
525
+
526
+ return alpha_t, sigma_t
527
+
528
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
529
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
530
+ """Constructs the noise schedule of Karras et al. (2022)."""
531
+
532
+ # Hack to make sure that other schedulers which copy this function don't break
533
+ # TODO: Add this logic to the other schedulers
534
+ if hasattr(self.config, "sigma_min"):
535
+ sigma_min = self.config.sigma_min
536
+ else:
537
+ sigma_min = None
538
+
539
+ if hasattr(self.config, "sigma_max"):
540
+ sigma_max = self.config.sigma_max
541
+ else:
542
+ sigma_max = None
543
+
544
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
545
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
546
+
547
+ rho = 7.0 # 7.0 is the value used in the paper
548
+ ramp = np.linspace(0, 1, num_inference_steps)
549
+ min_inv_rho = sigma_min ** (1 / rho)
550
+ max_inv_rho = sigma_max ** (1 / rho)
551
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
552
+ return sigmas
553
+
554
+ def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
555
+ """Constructs the noise schedule of Lu et al. (2022)."""
556
+
557
+ lambda_min: float = in_lambdas[-1].item()
558
+ lambda_max: float = in_lambdas[0].item()
559
+
560
+ rho = 1.0 # 1.0 is the value used in the paper
561
+ ramp = np.linspace(0, 1, num_inference_steps)
562
+ min_inv_rho = lambda_min ** (1 / rho)
563
+ max_inv_rho = lambda_max ** (1 / rho)
564
+ lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
565
+ return lambdas
566
+
567
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
568
+ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
569
+ """Constructs an exponential noise schedule."""
570
+
571
+ # Hack to make sure that other schedulers which copy this function don't break
572
+ # TODO: Add this logic to the other schedulers
573
+ if hasattr(self.config, "sigma_min"):
574
+ sigma_min = self.config.sigma_min
575
+ else:
576
+ sigma_min = None
577
+
578
+ if hasattr(self.config, "sigma_max"):
579
+ sigma_max = self.config.sigma_max
580
+ else:
581
+ sigma_max = None
582
+
583
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
584
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
585
+
586
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
587
+ return sigmas
588
+
589
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
590
+ def _convert_to_beta(
591
+ self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
592
+ ) -> torch.Tensor:
593
+ """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
594
+
595
+ # Hack to make sure that other schedulers which copy this function don't break
596
+ # TODO: Add this logic to the other schedulers
597
+ if hasattr(self.config, "sigma_min"):
598
+ sigma_min = self.config.sigma_min
599
+ else:
600
+ sigma_min = None
601
+
602
+ if hasattr(self.config, "sigma_max"):
603
+ sigma_max = self.config.sigma_max
604
+ else:
605
+ sigma_max = None
606
+
607
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
608
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
609
+
610
+ sigmas = np.array(
611
+ [
612
+ sigma_min + (ppf * (sigma_max - sigma_min))
613
+ for ppf in [
614
+ scipy.stats.beta.ppf(timestep, alpha, beta)
615
+ for timestep in 1 - np.linspace(0, 1, num_inference_steps)
616
+ ]
617
+ ]
618
+ )
619
+ return sigmas
620
+
621
def convert_model_output(
    self,
    model_output: torch.Tensor,
    *args,
    sample: torch.Tensor = None,
    **kwargs,
) -> torch.Tensor:
    """
    Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
    designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
    integral of the data prediction model.

    <Tip>

    The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
    prediction and data prediction models.

    </Tip>

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.

    Returns:
        `torch.Tensor`:
            The converted model output (`x0` prediction for `*dpmsolver++`, epsilon prediction for `*dpmsolver`).
    """
    # Backwards compatibility: `timestep`/`sample` used to be positional; the timestep
    # is now unused because the scheduler tracks progress via `self.step_index`.
    timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
    if sample is None:
        if len(args) > 1:
            sample = args[1]
        else:
            raise ValueError("missing `sample` as a required keyword argument")
    if timestep is not None:
        deprecate(
            "timesteps",
            "1.0.0",
            "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    # DPM-Solver++ needs to solve an integral of the data prediction model.
    if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
        if self.config.prediction_type == "epsilon":
            # DPM-Solver and DPM-Solver++ only need the "mean" output.
            if self.config.variance_type in ["learned", "learned_range"]:
                model_output = model_output[:, :3]
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            x0_pred = (sample - sigma_t * model_output) / alpha_t
        elif self.config.prediction_type == "sample":
            # Model directly predicts x0; nothing to convert.
            x0_pred = model_output
        elif self.config.prediction_type == "v_prediction":
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            x0_pred = alpha_t * sample - sigma_t * model_output
        elif self.config.prediction_type == "flow_prediction":
            # Flow-matching parameterization: here sigma is used directly, not split
            # into (alpha_t, sigma_t).
            sigma_t = self.sigmas[self.step_index]
            x0_pred = sample - sigma_t * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                "`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
            )

        if self.config.thresholding:
            x0_pred = self._threshold_sample(x0_pred)

        return x0_pred

    # DPM-Solver needs to solve an integral of the noise prediction model.
    elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
        if self.config.prediction_type == "epsilon":
            # DPM-Solver and DPM-Solver++ only need the "mean" output.
            if self.config.variance_type in ["learned", "learned_range"]:
                epsilon = model_output[:, :3]
            else:
                epsilon = model_output
        elif self.config.prediction_type == "sample":
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            epsilon = (sample - alpha_t * model_output) / sigma_t
        elif self.config.prediction_type == "v_prediction":
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            epsilon = alpha_t * model_output + sigma_t * sample
        else:
            # NOTE: `flow_prediction` is only supported by the `*dpmsolver++` branch above.
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction` for the DPMSolverMultistepScheduler."
            )

        if self.config.thresholding:
            # Thresholding is defined on x0, so round-trip: epsilon -> x0 ->
            # threshold -> epsilon.
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            x0_pred = (sample - sigma_t * epsilon) / alpha_t
            x0_pred = self._threshold_sample(x0_pred)
            epsilon = (sample - alpha_t * x0_pred) / sigma_t

        return epsilon
722
+
723
def dpm_solver_first_order_update(
    self,
    model_output: torch.Tensor,
    *args,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """
    One step for the first-order DPMSolver (equivalent to DDIM).

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model (already converted by
            `convert_model_output`).
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        noise (`torch.Tensor`, *optional*):
            Gaussian noise; required for the SDE variants (`sde-dpmsolver`, `sde-dpmsolver++`).

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.
    """
    # Backwards compatibility for the old positional `timestep`/`prev_timestep` API;
    # both are ignored now that `self.step_index` tracks progress.
    timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]
        else:
            raise ValueError("missing `sample` as a required keyword argument")
    if timestep is not None:
        deprecate(
            "timesteps",
            "1.0.0",
            "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    # `sigma_s` is the current step's sigma, `sigma_t` the next (target) step's.
    sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
    # lambda = log(alpha) - log(sigma): one half of the log-SNR.
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s = torch.log(alpha_s) - torch.log(sigma_s)

    # h is the step size in lambda (half-log-SNR) space.
    h = lambda_t - lambda_s
    if self.config.algorithm_type == "dpmsolver++":
        x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
    elif self.config.algorithm_type == "dpmsolver":
        x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
    elif self.config.algorithm_type == "sde-dpmsolver++":
        assert noise is not None
        x_t = (
            (sigma_t / sigma_s * torch.exp(-h)) * sample
            + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
            + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
        )
    elif self.config.algorithm_type == "sde-dpmsolver":
        assert noise is not None
        x_t = (
            (alpha_t / alpha_s) * sample
            - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
            + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
        )
    return x_t
791
+
792
def multistep_dpm_solver_second_order_update(
    self,
    model_output_list: List[torch.Tensor],
    *args,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """
    One step for the second-order multistep DPMSolver.

    Combines the converted model outputs of the current step (`model_output_list[-1]`)
    and the previous step (`model_output_list[-2]`) into a first-difference correction.

    Args:
        model_output_list (`List[torch.Tensor]`):
            The direct outputs from learned diffusion model at current and latter timesteps.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        noise (`torch.Tensor`, *optional*):
            Gaussian noise; required for the SDE variants.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.
    """
    # Backwards compatibility for the old positional API; both legacy arguments
    # are ignored now that `self.step_index` tracks progress.
    timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]
        else:
            raise ValueError("missing `sample` as a required keyword argument")
    if timestep_list is not None:
        deprecate(
            "timestep_list",
            "1.0.0",
            "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    # s0 is the current step, s1 the previous one, t the target (next) step.
    sigma_t, sigma_s0, sigma_s1 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
        self.sigmas[self.step_index - 1],
    )

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

    # lambda = log(alpha) - log(sigma): one half of the log-SNR.
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)

    m0, m1 = model_output_list[-1], model_output_list[-2]

    # h: current step size in lambda space; h_0: previous step size; r0 their ratio.
    h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
    r0 = h_0 / h
    # D0: zeroth-order term; D1: scaled first difference of the model outputs.
    D0, D1 = m0, (1.0 / r0) * (m0 - m1)
    if self.config.algorithm_type == "dpmsolver++":
        # See https://huggingface.co/papers/2211.01095 for detailed derivations
        if self.config.solver_type == "midpoint":
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
            )
        elif self.config.solver_type == "heun":
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
            )
    elif self.config.algorithm_type == "dpmsolver":
        # See https://huggingface.co/papers/2206.00927 for detailed derivations
        if self.config.solver_type == "midpoint":
            x_t = (
                (alpha_t / alpha_s0) * sample
                - (sigma_t * (torch.exp(h) - 1.0)) * D0
                - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
            )
        elif self.config.solver_type == "heun":
            x_t = (
                (alpha_t / alpha_s0) * sample
                - (sigma_t * (torch.exp(h) - 1.0)) * D0
                - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
            )
    elif self.config.algorithm_type == "sde-dpmsolver++":
        assert noise is not None
        if self.config.solver_type == "midpoint":
            x_t = (
                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
        elif self.config.solver_type == "heun":
            x_t = (
                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
    elif self.config.algorithm_type == "sde-dpmsolver":
        assert noise is not None
        if self.config.solver_type == "midpoint":
            x_t = (
                (alpha_t / alpha_s0) * sample
                - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                - (sigma_t * (torch.exp(h) - 1.0)) * D1
                + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
            )
        elif self.config.solver_type == "heun":
            x_t = (
                (alpha_t / alpha_s0) * sample
                - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
            )
    return x_t
914
+
915
def multistep_dpm_solver_third_order_update(
    self,
    model_output_list: List[torch.Tensor],
    *args,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """
    One step for the third-order multistep DPMSolver.

    Uses the three most recent converted model outputs (`model_output_list[-1]`,
    `[-2]`, `[-3]`) to form first- and second-difference corrections.

    Args:
        model_output_list (`List[torch.Tensor]`):
            The direct outputs from learned diffusion model at current and latter timesteps.
        sample (`torch.Tensor`):
            A current instance of a sample created by diffusion process.
        noise (`torch.Tensor`, *optional*):
            Gaussian noise; required for the `sde-dpmsolver++` variant.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.
    """

    # Backwards compatibility for the old positional API; both legacy arguments
    # are ignored now that `self.step_index` tracks progress.
    timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]
        else:
            raise ValueError("missing `sample` as a required keyword argument")
    if timestep_list is not None:
        deprecate(
            "timestep_list",
            "1.0.0",
            "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    # s0 is the current step, s1/s2 the two previous ones, t the target (next) step.
    sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
        self.sigmas[self.step_index - 1],
        self.sigmas[self.step_index - 2],
    )

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
    alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

    # lambda = log(alpha) - log(sigma): one half of the log-SNR.
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
    lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)

    m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

    # Step sizes in lambda space and their ratios to the current step `h`.
    h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
    r0, r1 = h_0 / h, h_1 / h
    # D0/D1/D2: zeroth-, first- and second-difference terms of the model outputs.
    D0 = m0
    D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
    D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
    D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
    if self.config.algorithm_type == "dpmsolver++":
        # See https://huggingface.co/papers/2206.00927 for detailed derivations
        x_t = (
            (sigma_t / sigma_s0) * sample
            - (alpha_t * (torch.exp(-h) - 1.0)) * D0
            + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
            - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
        )
    elif self.config.algorithm_type == "dpmsolver":
        # See https://huggingface.co/papers/2206.00927 for detailed derivations
        x_t = (
            (alpha_t / alpha_s0) * sample
            - (sigma_t * (torch.exp(h) - 1.0)) * D0
            - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
            - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
        )
    elif self.config.algorithm_type == "sde-dpmsolver++":
        assert noise is not None
        x_t = (
            (sigma_t / sigma_s0 * torch.exp(-h)) * sample
            + (alpha_t * (1.0 - torch.exp(-2.0 * h))) * D0
            + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
            + (alpha_t * ((1.0 - torch.exp(-2.0 * h) - 2.0 * h) / (2.0 * h) ** 2 - 0.5)) * D2
            + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
        )
    return x_t
1009
+
1010
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """Locate the step index for `timestep` within `schedule_timesteps`.

    Falls back to `self.timesteps` when no schedule is passed. If the timestep
    does not occur in the schedule, the last index is returned.
    """
    if schedule_timesteps is None:
        schedule_timesteps = self.timesteps

    matches = (schedule_timesteps == timestep).nonzero()

    # Timestep not found in the schedule: clamp to the final index.
    if len(matches) == 0:
        return len(self.timesteps) - 1

    # The sigma index that is taken for the **very** first `step`
    # is always the second index (or the last index if there is only 1)
    # This way we can ensure we don't accidentally skip a sigma in
    # case we start in the middle of the denoising schedule (e.g. for image-to-image)
    pick = 1 if len(matches) > 1 else 0
    return matches[pick].item()
1028
+
1029
def _init_step_index(self, timestep):
    """
    Initialize the step_index counter for the scheduler.

    Uses the pipeline-provided `begin_index` when available; otherwise resolves
    `timestep` against the timestep schedule.
    """

    if self.begin_index is not None:
        self._step_index = self._begin_index
        return

    if isinstance(timestep, torch.Tensor):
        # Comparison against self.timesteps requires matching devices.
        timestep = timestep.to(self.timesteps.device)
    self._step_index = self.index_for_timestep(timestep)
1040
+
1041
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[int, torch.Tensor],
    sample: torch.Tensor,
    generator=None,
    variance_noise: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
    the multistep DPMSolver.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        generator (`torch.Generator`, *optional*):
            A random number generator.
        variance_noise (`torch.Tensor`):
            Alternative to generating noise with `generator` by directly providing the noise for the variance
            itself. Useful for methods such as [`LEdits++`].
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # Lazily resolve the internal counter on the first call.
    if self.step_index is None:
        self._init_step_index(timestep)

    # Improve numerical stability for small number of steps: force a lower-order
    # (or Euler) update at the last one or two steps where the higher-order
    # differences become unreliable.
    lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
        self.config.euler_at_final
        or (self.config.lower_order_final and len(self.timesteps) < 15)
        or self.config.final_sigmas_type == "zero"
    )
    lower_order_second = (
        (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
    )

    # Shift the rolling history of converted model outputs and append the newest.
    model_output = self.convert_model_output(model_output, sample=sample)
    for i in range(self.config.solver_order - 1):
        self.model_outputs[i] = self.model_outputs[i + 1]
    self.model_outputs[-1] = model_output

    # Upcast to avoid precision issues when computing prev_sample
    sample = sample.to(torch.float32)
    if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
        noise = randn_tensor(
            model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
        )
    elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
        noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
    else:
        noise = None

    # Dispatch to the highest update order allowed by the configured solver
    # order, the amount of accumulated history, and the stability overrides.
    if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
        prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
    elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
        prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
    else:
        prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample, noise=noise)

    if self.lower_order_nums < self.config.solver_order:
        self.lower_order_nums += 1

    # Cast sample back to expected dtype
    prev_sample = prev_sample.to(model_output.dtype)

    # upon completion increase step index by one
    self._step_index += 1

    if not return_dict:
        return (prev_sample,)

    return SchedulerOutput(prev_sample=prev_sample)
1129
+
1130
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """
    Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
    current timestep.

    Args:
        sample (`torch.Tensor`):
            The input sample.

    Returns:
        `torch.Tensor`:
            A scaled input sample.
    """
    # DPMSolver needs no input scaling; the hook exists so pipelines can call it
    # uniformly across all scheduler types.
    return sample
1144
+
1145
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """
    Diffuse `original_samples` forward to the noise levels of `timesteps`:
    `alpha_t * sample + sigma_t * noise`, with per-element sigmas resolved from
    the current schedule.
    """
    # Make sure sigmas and timesteps have the same device and dtype as original_samples
    sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
    if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
        # mps does not support float64
        schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
        timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
    else:
        schedule_timesteps = self.timesteps.to(original_samples.device)
        timesteps = timesteps.to(original_samples.device)

    # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
    if self.begin_index is None:
        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
    elif self.step_index is not None:
        # add_noise is called after first denoising step (for inpainting)
        step_indices = [self.step_index] * timesteps.shape[0]
    else:
        # add noise is called before first denoising step to create initial latent(img2img)
        step_indices = [self.begin_index] * timesteps.shape[0]

    # Broadcast the per-sample sigma against the sample's trailing dimensions.
    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < len(original_samples.shape):
        sigma = sigma.unsqueeze(-1)

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
    noisy_samples = alpha_t * original_samples + sigma_t * noise
    return noisy_samples
1178
+
1179
def __len__(self):
    """The scheduler's length is the number of training timesteps."""
    total_timesteps = self.config.num_train_timesteps
    return total_timesteps
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py ADDED
@@ -0,0 +1,645 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
16
+
17
+ from dataclasses import dataclass
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ import flax
21
+ import jax
22
+ import jax.numpy as jnp
23
+
24
+ from ..configuration_utils import ConfigMixin, register_to_config
25
+ from .scheduling_utils_flax import (
26
+ CommonSchedulerState,
27
+ FlaxKarrasDiffusionSchedulers,
28
+ FlaxSchedulerMixin,
29
+ FlaxSchedulerOutput,
30
+ add_noise_common,
31
+ )
32
+
33
+
34
@flax.struct.dataclass
class DPMSolverMultistepSchedulerState:
    """Immutable pytree holding all mutable state of `FlaxDPMSolverMultistepScheduler`.

    The Flax scheduler object itself is stateless; every update returns a new
    instance of this dataclass.
    """

    common: CommonSchedulerState
    alpha_t: jnp.ndarray  # sqrt(alphas_cumprod)
    sigma_t: jnp.ndarray  # sqrt(1 - alphas_cumprod)
    lambda_t: jnp.ndarray  # log(alpha_t) - log(sigma_t)

    # setable values
    init_noise_sigma: jnp.ndarray
    timesteps: jnp.ndarray
    num_inference_steps: Optional[int] = None

    # running values
    model_outputs: Optional[jnp.ndarray] = None  # rolling history of converted model outputs
    lower_order_nums: Optional[jnp.int32] = None  # how many lower-order steps have run so far
    prev_timestep: Optional[jnp.int32] = None
    cur_sample: Optional[jnp.ndarray] = None

    @classmethod
    def create(
        cls,
        common: CommonSchedulerState,
        alpha_t: jnp.ndarray,
        sigma_t: jnp.ndarray,
        lambda_t: jnp.ndarray,
        init_noise_sigma: jnp.ndarray,
        timesteps: jnp.ndarray,
    ):
        """Build an initial state; the running values stay at their `None` defaults."""
        return cls(
            common=common,
            alpha_t=alpha_t,
            sigma_t=sigma_t,
            lambda_t=lambda_t,
            init_noise_sigma=init_noise_sigma,
            timesteps=timesteps,
        )
70
+
71
+
72
@dataclass
class FlaxDPMSolverMultistepSchedulerOutput(FlaxSchedulerOutput):
    """Scheduler output that additionally carries the updated immutable `state`."""

    state: DPMSolverMultistepSchedulerState
75
+
76
+
77
+ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
78
+ """
79
+ DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with
80
+ the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality
81
+ samples, and it can generate quite good samples even in only 10 steps.
82
+
83
+ For more details, see the original paper: https://huggingface.co/papers/2206.00927 and
84
+ https://huggingface.co/papers/2211.01095
85
+
86
+ Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We
87
+ recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
88
+
89
+ We also support the "dynamic thresholding" method in Imagen (https://huggingface.co/papers/2205.11487). For
90
+ pixel-space diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the
91
+ dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
92
+ stable-diffusion).
93
+
94
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
95
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
96
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
97
+ [`~SchedulerMixin.from_pretrained`] functions.
98
+
99
+ For more details, see the original paper: https://huggingface.co/papers/2206.00927 and
100
+ https://huggingface.co/papers/2211.01095
101
+
102
+ Args:
103
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
104
+ beta_start (`float`): the starting `beta` value of inference.
105
+ beta_end (`float`): the final `beta` value.
106
+ beta_schedule (`str`):
107
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
108
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
109
+ trained_betas (`np.ndarray`, optional):
110
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
111
+ solver_order (`int`, default `2`):
112
+ the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
113
+ sampling, and `solver_order=3` for unconditional sampling.
114
+ prediction_type (`str`, default `epsilon`):
115
+ indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
116
+ or `v-prediction`.
117
+ thresholding (`bool`, default `False`):
118
+ whether to use the "dynamic thresholding" method (introduced by Imagen,
119
+ https://huggingface.co/papers/2205.11487). For pixel-space diffusion models, you can set both
120
+ `algorithm_type=dpmsolver++` and `thresholding=True` to use the dynamic thresholding. Note that the
121
+ thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion).
122
+ dynamic_thresholding_ratio (`float`, default `0.995`):
123
+ the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
124
+ (https://huggingface.co/papers/2205.11487).
125
+ sample_max_value (`float`, default `1.0`):
126
+ the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
127
+ `algorithm_type="dpmsolver++`.
128
+ algorithm_type (`str`, default `dpmsolver++`):
129
+ the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
130
+ algorithms in https://huggingface.co/papers/2206.00927, and the `dpmsolver++` type implements the
131
+ algorithms in https://huggingface.co/papers/2211.01095. We recommend to use `dpmsolver++` with
132
+ `solver_order=2` for guided sampling (e.g. stable-diffusion).
133
+ solver_type (`str`, default `midpoint`):
134
+ the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
135
+ the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
136
+ slightly better, so we recommend to use the `midpoint` type.
137
+ lower_order_final (`bool`, default `True`):
138
+ whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
139
+ find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
140
+ timestep_spacing (`str`, defaults to `"linspace"`):
141
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
142
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
143
+ dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
144
+ the `dtype` used for params and computation.
145
+ """
146
+
147
+ _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
148
+
149
+ dtype: jnp.dtype
150
+
151
@property
def has_state(self):
    # Flax schedulers keep all mutable values in a separate state pytree
    # (see `create_state`), so this always reports True.
    return True
154
+
155
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    beta_schedule: str = "linear",
    trained_betas: Optional[jnp.ndarray] = None,
    solver_order: int = 2,
    prediction_type: str = "epsilon",
    thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    algorithm_type: str = "dpmsolver++",
    solver_type: str = "midpoint",
    lower_order_final: bool = True,
    timestep_spacing: str = "linspace",
    dtype: jnp.dtype = jnp.float32,
):
    # All arguments are recorded on `self.config` by `@register_to_config`;
    # the noise schedule itself is built lazily in `create_state` so the
    # scheduler object stays stateless (JAX-friendly). Only the computation
    # dtype is kept on the instance.
    self.dtype = dtype
175
+
176
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:
    """Build and return the initial scheduler state for the configured VP-type schedule."""
    # Reject unsupported configurations up front (independent of `common`).
    if self.config.algorithm_type not in ["dpmsolver", "dpmsolver++"]:
        raise NotImplementedError(f"{self.config.algorithm_type} is not implemented for {self.__class__}")
    if self.config.solver_type not in ["midpoint", "heun"]:
        raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}")

    if common is None:
        common = CommonSchedulerState.create(self)

    # Currently only the VP-type noise schedule is supported:
    # alpha_t^2 + sigma_t^2 = 1, lambda_t = log(alpha_t / sigma_t).
    alpha_t = jnp.sqrt(common.alphas_cumprod)
    sigma_t = jnp.sqrt(1 - common.alphas_cumprod)
    lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t)

    # Standard deviation of the initial noise distribution.
    init_noise_sigma = jnp.array(1.0, dtype=self.dtype)

    # Full training schedule, newest timestep first.
    timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]

    return DPMSolverMultistepSchedulerState.create(
        common=common,
        alpha_t=alpha_t,
        sigma_t=sigma_t,
        lambda_t=lambda_t,
        init_noise_sigma=init_noise_sigma,
        timesteps=timesteps,
    )
204
+
205
+ def set_timesteps(
206
+ self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple
207
+ ) -> DPMSolverMultistepSchedulerState:
208
+ """
209
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
210
+
211
+ Args:
212
+ state (`DPMSolverMultistepSchedulerState`):
213
+ the `FlaxDPMSolverMultistepScheduler` state data class instance.
214
+ num_inference_steps (`int`):
215
+ the number of diffusion steps used when generating samples with a pre-trained model.
216
+ shape (`Tuple`):
217
+ the shape of the samples to be generated.
218
+ """
219
+ last_timestep = self.config.num_train_timesteps
220
+ if self.config.timestep_spacing == "linspace":
221
+ timesteps = (
222
+ jnp.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].astype(jnp.int32)
223
+ )
224
+ elif self.config.timestep_spacing == "leading":
225
+ step_ratio = last_timestep // (num_inference_steps + 1)
226
+ # creates integer timesteps by multiplying by ratio
227
+ # casting to int to avoid issues when num_inference_step is power of 3
228
+ timesteps = (
229
+ (jnp.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(jnp.int32)
230
+ )
231
+ timesteps += self.config.steps_offset
232
+ elif self.config.timestep_spacing == "trailing":
233
+ step_ratio = self.config.num_train_timesteps / num_inference_steps
234
+ # creates integer timesteps by multiplying by ratio
235
+ # casting to int to avoid issues when num_inference_step is power of 3
236
+ timesteps = jnp.arange(last_timestep, 0, -step_ratio).round().copy().astype(jnp.int32)
237
+ timesteps -= 1
238
+ else:
239
+ raise ValueError(
240
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
241
+ )
242
+
243
+ # initial running values
244
+
245
+ model_outputs = jnp.zeros((self.config.solver_order,) + shape, dtype=self.dtype)
246
+ lower_order_nums = jnp.int32(0)
247
+ prev_timestep = jnp.int32(-1)
248
+ cur_sample = jnp.zeros(shape, dtype=self.dtype)
249
+
250
+ return state.replace(
251
+ num_inference_steps=num_inference_steps,
252
+ timesteps=timesteps,
253
+ model_outputs=model_outputs,
254
+ lower_order_nums=lower_order_nums,
255
+ prev_timestep=prev_timestep,
256
+ cur_sample=cur_sample,
257
+ )
258
+
259
+ def convert_model_output(
260
+ self,
261
+ state: DPMSolverMultistepSchedulerState,
262
+ model_output: jnp.ndarray,
263
+ timestep: int,
264
+ sample: jnp.ndarray,
265
+ ) -> jnp.ndarray:
266
+ """
267
+ Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
268
+
269
+ DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
270
+ discretize an integral of the data prediction model. So we need to first convert the model output to the
271
+ corresponding type to match the algorithm.
272
+
273
+ Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or
274
+ DPM-Solver++ for both noise prediction model and data prediction model.
275
+
276
+ Args:
277
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
278
+ timestep (`int`): current discrete timestep in the diffusion chain.
279
+ sample (`jnp.ndarray`):
280
+ current instance of sample being created by diffusion process.
281
+
282
+ Returns:
283
+ `jnp.ndarray`: the converted model output.
284
+ """
285
+ # DPM-Solver++ needs to solve an integral of the data prediction model.
286
+ if self.config.algorithm_type == "dpmsolver++":
287
+ if self.config.prediction_type == "epsilon":
288
+ alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
289
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
290
+ elif self.config.prediction_type == "sample":
291
+ x0_pred = model_output
292
+ elif self.config.prediction_type == "v_prediction":
293
+ alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
294
+ x0_pred = alpha_t * sample - sigma_t * model_output
295
+ else:
296
+ raise ValueError(
297
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
298
+ " or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
299
+ )
300
+
301
+ if self.config.thresholding:
302
+ # Dynamic thresholding in https://huggingface.co/papers/2205.11487
303
+ dynamic_max_val = jnp.percentile(
304
+ jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim))
305
+ )
306
+ dynamic_max_val = jnp.maximum(
307
+ dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val)
308
+ )
309
+ x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
310
+ return x0_pred
311
+ # DPM-Solver needs to solve an integral of the noise prediction model.
312
+ elif self.config.algorithm_type == "dpmsolver":
313
+ if self.config.prediction_type == "epsilon":
314
+ return model_output
315
+ elif self.config.prediction_type == "sample":
316
+ alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
317
+ epsilon = (sample - alpha_t * model_output) / sigma_t
318
+ return epsilon
319
+ elif self.config.prediction_type == "v_prediction":
320
+ alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
321
+ epsilon = alpha_t * model_output + sigma_t * sample
322
+ return epsilon
323
+ else:
324
+ raise ValueError(
325
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
326
+ " or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
327
+ )
328
+
329
+ def dpm_solver_first_order_update(
330
+ self,
331
+ state: DPMSolverMultistepSchedulerState,
332
+ model_output: jnp.ndarray,
333
+ timestep: int,
334
+ prev_timestep: int,
335
+ sample: jnp.ndarray,
336
+ ) -> jnp.ndarray:
337
+ """
338
+ One step for the first-order DPM-Solver (equivalent to DDIM).
339
+
340
+ See https://huggingface.co/papers/2206.00927 for the detailed derivation.
341
+
342
+ Args:
343
+ model_output (`jnp.ndarray`): direct output from learned diffusion model.
344
+ timestep (`int`): current discrete timestep in the diffusion chain.
345
+ prev_timestep (`int`): previous discrete timestep in the diffusion chain.
346
+ sample (`jnp.ndarray`):
347
+ current instance of sample being created by diffusion process.
348
+
349
+ Returns:
350
+ `jnp.ndarray`: the sample tensor at the previous timestep.
351
+ """
352
+ t, s0 = prev_timestep, timestep
353
+ m0 = model_output
354
+ lambda_t, lambda_s = state.lambda_t[t], state.lambda_t[s0]
355
+ alpha_t, alpha_s = state.alpha_t[t], state.alpha_t[s0]
356
+ sigma_t, sigma_s = state.sigma_t[t], state.sigma_t[s0]
357
+ h = lambda_t - lambda_s
358
+ if self.config.algorithm_type == "dpmsolver++":
359
+ x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0
360
+ elif self.config.algorithm_type == "dpmsolver":
361
+ x_t = (alpha_t / alpha_s) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * m0
362
+ return x_t
363
+
364
def multistep_dpm_solver_second_order_update(
    self,
    state: DPMSolverMultistepSchedulerState,
    model_output_list: jnp.ndarray,
    timestep_list: List[int],
    prev_timestep: int,
    sample: jnp.ndarray,
) -> jnp.ndarray:
    """
    One step for the second-order multistep DPM-Solver.

    Args:
        model_output_list (`List[jnp.ndarray]`):
            direct outputs from learned diffusion model at current and latter timesteps.
        timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
        prev_timestep (`int`): previous discrete timestep in the diffusion chain.
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.

    Returns:
        `jnp.ndarray`: the sample tensor at the previous timestep.
    """
    # t: target timestep; s0, s1: the two most recent model-evaluation timesteps.
    t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
    m0, m1 = model_output_list[-1], model_output_list[-2]
    lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1]
    alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0]
    sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0]
    # h: current log-SNR step; h_0: previous step; r0: their ratio.
    h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
    r0 = h_0 / h
    # D0: value term; D1: finite-difference estimate of the first derivative.
    D0, D1 = m0, (1.0 / r0) * (m0 - m1)
    if self.config.algorithm_type == "dpmsolver++":
        # See https://huggingface.co/papers/2211.01095 for detailed derivations
        if self.config.solver_type == "midpoint":
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (jnp.exp(-h) - 1.0)) * D0
                - 0.5 * (alpha_t * (jnp.exp(-h) - 1.0)) * D1
            )
        elif self.config.solver_type == "heun":
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (jnp.exp(-h) - 1.0)) * D0
                + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1
            )
    elif self.config.algorithm_type == "dpmsolver":
        # See https://huggingface.co/papers/2206.00927 for detailed derivations
        if self.config.solver_type == "midpoint":
            x_t = (
                (alpha_t / alpha_s0) * sample
                - (sigma_t * (jnp.exp(h) - 1.0)) * D0
                - 0.5 * (sigma_t * (jnp.exp(h) - 1.0)) * D1
            )
        elif self.config.solver_type == "heun":
            x_t = (
                (alpha_t / alpha_s0) * sample
                - (sigma_t * (jnp.exp(h) - 1.0)) * D0
                - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1
            )
    return x_t
423
+
424
def multistep_dpm_solver_third_order_update(
    self,
    state: DPMSolverMultistepSchedulerState,
    model_output_list: jnp.ndarray,
    timestep_list: List[int],
    prev_timestep: int,
    sample: jnp.ndarray,
) -> jnp.ndarray:
    """
    One step for the third-order multistep DPM-Solver.

    Args:
        model_output_list (`List[jnp.ndarray]`):
            direct outputs from learned diffusion model at current and latter timesteps.
        timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
        prev_timestep (`int`): previous discrete timestep in the diffusion chain.
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.

    Returns:
        `jnp.ndarray`: the sample tensor at the previous timestep.
    """
    # t: target timestep; s0, s1, s2: the three most recent model-evaluation timesteps.
    t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
    m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
    lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
        state.lambda_t[t],
        state.lambda_t[s0],
        state.lambda_t[s1],
        state.lambda_t[s2],
    )
    alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0]
    sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0]
    # Successive log-SNR steps and their ratios to the current step h.
    h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
    r0, r1 = h_0 / h, h_1 / h
    # D0: value; D1: first-derivative estimate; D2: second-derivative estimate,
    # built from two backward finite differences D1_0 and D1_1.
    D0 = m0
    D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
    D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
    D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
    if self.config.algorithm_type == "dpmsolver++":
        # See https://huggingface.co/papers/2206.00927 for detailed derivations
        x_t = (
            (sigma_t / sigma_s0) * sample
            - (alpha_t * (jnp.exp(-h) - 1.0)) * D0
            + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1
            - (alpha_t * ((jnp.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
        )
    elif self.config.algorithm_type == "dpmsolver":
        # See https://huggingface.co/papers/2206.00927 for detailed derivations
        x_t = (
            (alpha_t / alpha_s0) * sample
            - (sigma_t * (jnp.exp(h) - 1.0)) * D0
            - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1
            - (sigma_t * ((jnp.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
        )
    return x_t
479
+
480
def step(
    self,
    state: DPMSolverMultistepSchedulerState,
    model_output: jnp.ndarray,
    timestep: int,
    sample: jnp.ndarray,
    return_dict: bool = True,
) -> Union[FlaxDPMSolverMultistepSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by DPM-Solver. Core function to propagate the diffusion process
    from the learned model outputs (most often the predicted noise).

    Args:
        state (`DPMSolverMultistepSchedulerState`):
            the `FlaxDPMSolverMultistepScheduler` state data class instance.
        model_output (`jnp.ndarray`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.
        return_dict (`bool`): option for returning tuple rather than FlaxDPMSolverMultistepSchedulerOutput class

    Returns:
        [`FlaxDPMSolverMultistepSchedulerOutput`] or `tuple`: [`FlaxDPMSolverMultistepSchedulerOutput`] if
        `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

    """
    if state.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # jit-compatible lookup of `timestep` within the inference schedule.
    (step_index,) = jnp.where(state.timesteps == timestep, size=1)
    step_index = step_index[0]

    # The final step integrates down to t = 0.
    prev_timestep = jax.lax.select(step_index == len(state.timesteps) - 1, 0, state.timesteps[step_index + 1])

    model_output = self.convert_model_output(state, model_output, timestep, sample)

    # Shift the history buffer left and append the newest converted output.
    model_outputs_new = jnp.roll(state.model_outputs, -1, axis=0)
    model_outputs_new = model_outputs_new.at[-1].set(model_output)
    state = state.replace(
        model_outputs=model_outputs_new,
        prev_timestep=prev_timestep,
        cur_sample=sample,
    )

    def step_1(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
        # First-order update using only the newest model output.
        return self.dpm_solver_first_order_update(
            state,
            state.model_outputs[-1],
            state.timesteps[step_index],
            state.prev_timestep,
            state.cur_sample,
        )

    def step_23(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
        def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
            timestep_list = jnp.array([state.timesteps[step_index - 1], state.timesteps[step_index]])
            return self.multistep_dpm_solver_second_order_update(
                state,
                state.model_outputs,
                timestep_list,
                state.prev_timestep,
                state.cur_sample,
            )

        def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
            timestep_list = jnp.array(
                [
                    state.timesteps[step_index - 2],
                    state.timesteps[step_index - 1],
                    state.timesteps[step_index],
                ]
            )
            return self.multistep_dpm_solver_third_order_update(
                state,
                state.model_outputs,
                timestep_list,
                state.prev_timestep,
                state.cur_sample,
            )

        # NOTE: both orders are computed eagerly; `lax.select` only picks the
        # result. This is required for jit/tracing, at extra compute cost.
        step_2_output = step_2(state)
        step_3_output = step_3(state)

        if self.config.solver_order == 2:
            return step_2_output
        elif self.config.lower_order_final and len(state.timesteps) < 15:
            # Use a lower order during warm-up and for the penultimate step.
            return jax.lax.select(
                state.lower_order_nums < 2,
                step_2_output,
                jax.lax.select(
                    step_index == len(state.timesteps) - 2,
                    step_2_output,
                    step_3_output,
                ),
            )
        else:
            return jax.lax.select(
                state.lower_order_nums < 2,
                step_2_output,
                step_3_output,
            )

    step_1_output = step_1(state)
    step_23_output = step_23(state)

    if self.config.solver_order == 1:
        prev_sample = step_1_output

    elif self.config.lower_order_final and len(state.timesteps) < 15:
        # First-order during warm-up and for the very last step (stabilizes
        # sampling with few inference steps).
        prev_sample = jax.lax.select(
            state.lower_order_nums < 1,
            step_1_output,
            jax.lax.select(
                step_index == len(state.timesteps) - 1,
                step_1_output,
                step_23_output,
            ),
        )

    else:
        prev_sample = jax.lax.select(
            state.lower_order_nums < 1,
            step_1_output,
            step_23_output,
        )

    # Advance the warm-up counter, capped at the configured solver order.
    state = state.replace(
        lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.config.solver_order),
    )

    if not return_dict:
        return (prev_sample, state)

    return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state)
616
+
617
def scale_model_input(
    self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
) -> jnp.ndarray:
    """
    Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
    current timestep.

    Args:
        state (`DPMSolverMultistepSchedulerState`):
            the `FlaxDPMSolverMultistepScheduler` state data class instance.
        sample (`jnp.ndarray`): input sample
        timestep (`int`, optional): current timestep

    Returns:
        `jnp.ndarray`: scaled input sample
    """
    # DPM-Solver requires no input scaling; the sample is returned unchanged.
    return sample
634
+
635
def add_noise(
    self,
    state: DPMSolverMultistepSchedulerState,
    original_samples: jnp.ndarray,
    noise: jnp.ndarray,
    timesteps: jnp.ndarray,
) -> jnp.ndarray:
    """Forward-diffuse `original_samples` to the given `timesteps` via the shared schedule helper."""
    return add_noise_common(state.common, original_samples, noise, timesteps)
643
+
644
def __len__(self):
    # The scheduler's "length" is the size of the full training schedule.
    return self.config.num_train_timesteps
pythonProject/.venv/Lib/site-packages/functorch/compile/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch._functorch import config
2
+ from torch._functorch.aot_autograd import (
3
+ aot_function,
4
+ aot_module,
5
+ aot_module_simplified,
6
+ compiled_function,
7
+ compiled_module,
8
+ get_aot_compilation_context,
9
+ get_aot_graph_name,
10
+ get_graph_being_compiled,
11
+ make_boxed_compiler,
12
+ make_boxed_func,
13
+ )
14
+ from torch._functorch.compilers import (
15
+ debug_compile,
16
+ default_decompositions,
17
+ draw_graph_compile,
18
+ memory_efficient_fusion,
19
+ nnc_jit,
20
+ nop,
21
+ print_compile,
22
+ ts_compile,
23
+ )
24
+ from torch._functorch.fx_minifier import minifier
25
+ from torch._functorch.partitioners import (
26
+ default_partition,
27
+ draw_graph,
28
+ min_cut_rematerialization_partition,
29
+ )
30
+ from torch._functorch.python_key import pythonkey_decompose
pythonProject/.venv/Lib/site-packages/functorch/compile/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.06 kB). View file
 
pythonProject/.venv/Lib/site-packages/functorch/dim/batch_tensor.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ from contextlib import contextmanager
7
+
8
+ from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers
9
+
10
# Tracks whether a set of vmap layers is currently installed (no nesting allowed).
_enabled = False


@contextmanager
def _enable_layers(dims):
    """Temporarily install a vmap layer for every first-class dim in `dims`."""
    global _enabled
    assert not _enabled
    # Plain ints are positional dims and need no vmap layer.
    layer_levels = sorted((d._level, d.size) for d in dims if not isinstance(d, int))
    layer_count = len(layer_levels)
    try:
        _vmap_add_layers(layer_levels)
        _enabled = True
        yield
    finally:
        # Always tear down exactly the layers we added, even on error.
        _enabled = False
        _vmap_remove_layers(layer_count)
pythonProject/.venv/Lib/site-packages/functorch/dim/delayed_mul_tensor.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import torch
7
+
8
+ from . import _Tensor, Tensor
9
+ from .reference import _dims, _enable_layers, llist, ltuple
10
+
11
+
12
class DelayedMulTensor(_Tensor):
    """Lazy elementwise product of two first-class-dim tensors.

    The multiply is deferred so that a following `sum` over dims can be fused
    into a single `torch.einsum`; the product is only materialized if
    `_batchtensor` / `_tensor` is actually requested.
    """

    def __init__(self, lhs, rhs):
        self._lhs, self._rhs = lhs, rhs
        # NOTE(review): `_data` is never read in this class; it may be expected
        # by `_Tensor` machinery elsewhere — confirm before removing.
        self._data = None
        self._levels_data = None
        self._has_device = lhs._has_device or rhs._has_device
        self._batchtensor_data = None
        self._tensor_data = None

    @property
    def _levels(self):
        # Union of both operands' levels, lhs order first, computed once.
        if self._levels_data is None:
            levels = llist(self._lhs._levels)
            for l in self._rhs._levels:
                if l not in levels:
                    levels.append(l)
            self._levels_data = ltuple(levels)
        return self._levels_data

    @property
    def _batchtensor(self):
        # Fallback path: actually materialize the product as a batched tensor.
        if self._batchtensor_data is None:
            with _enable_layers(self._levels):
                print("bt multiply fallback")
                self._batchtensor_data = self._lhs._batchtensor * self._rhs._batchtensor
        return self._batchtensor_data

    @property
    def _tensor(self):
        if self._tensor_data is None:
            self._tensor_data = Tensor.from_batched(
                self._batchtensor, self._has_device
            )._tensor
        return self._tensor_data

    @property
    def ndim(self):
        return self._batchtensor.ndim

    @property
    def dims(self):
        return ltuple(super().dims)

    def sum(self, dim):
        """Sum out `dim`, fusing the pending multiply into one einsum call."""
        dims = _dims(dim, 0, False, False)
        n = ord("a")
        all_levels = self._levels

        def to_char(d):
            # Map each level to a distinct einsum subscript letter.
            return chr(n + all_levels.index(d))

        plhs, levelslhs = self._lhs._tensor, self._lhs._levels
        prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
        # (fix: removed a dead `new_dims` computation that was never used)
        new_levels = [l for l in self._levels if l not in dims]
        fmt = "".join(
            [
                *(to_char(d) for d in levelslhs),
                ",",
                *(to_char(d) for d in levelsrhs),
                "->",
                *(to_char(d) for d in new_levels),
            ]
        )
        result_data = torch.einsum(fmt, (plhs, prhs))
        return Tensor.from_positional(result_data, new_levels, True)
pythonProject/.venv/Lib/site-packages/functorch/dim/dim.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import dis
7
+ import inspect
8
+
9
+ from dataclasses import dataclass
10
+ from typing import Union
11
+
12
+ from . import DimList
13
+
14
# Stack of LevelInfo entries mirroring the active vmap nesting levels.
_vmap_levels = []


@dataclass
class LevelInfo:
    # vmap level id for this entry.
    level: int
    # Set to False when the owning Dim is garbage collected (see Dim.__del__).
    alive: bool = True
21
+
22
+
23
class Dim:
    """A named first-class dimension whose size binds lazily on first use.

    NOTE(review): `__del__` and the size setter reference `_vmap_active_levels`,
    `current_level`, `_vmap_increment_nesting` and `_vmap_decrement_nesting`,
    none of which are defined in this module (the `# noqa: F821` markers are
    original). This Python reference path appears to assume those names are
    provided elsewhere — confirm before relying on it.
    """

    def __init__(self, name: str, size: Union[None, int] = None):
        self.name = name
        self._size = None
        self._vmap_level = None
        if size is not None:
            # Goes through the `size` property setter below, which also opens
            # the matching vmap level.
            self.size = size

    def __del__(self):
        # Mark this dim's vmap level dead, then pop any trailing dead levels.
        if self._vmap_level is not None:
            _vmap_active_levels[self._vmap_stack].alive = False  # noqa: F821
            while (
                not _vmap_levels[-1].alive
                and current_level() == _vmap_levels[-1].level  # noqa: F821
            ):
                _vmap_decrement_nesting()  # noqa: F821
                _vmap_levels.pop()

    @property
    def size(self):
        # Reading the size of an unbound dim is a programming error.
        assert self.is_bound
        return self._size

    @size.setter
    def size(self, size: int):
        from . import DimensionBindError

        if self._size is None:
            # First bind: record the size and open a matching vmap level.
            self._size = size
            self._vmap_level = _vmap_increment_nesting(size, "same")  # noqa: F821
            self._vmap_stack = len(_vmap_levels)
            _vmap_levels.append(LevelInfo(self._vmap_level))

        elif self._size != size:
            # Rebinding to a different size is an error; rebinding to the same
            # size is a silent no-op.
            raise DimensionBindError(
                f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
            )

    @property
    def is_bound(self):
        return self._size is not None

    def __repr__(self):
        return self.name
67
+
68
+
69
def extract_name(inst):
    """Return the variable name targeted by a STORE_FAST / STORE_NAME instruction."""
    assert inst.opname in ("STORE_FAST", "STORE_NAME")
    return inst.argval
72
+
73
+
74
# Memoizes, per (code object, bytecode offset) call site, a factory that
# recreates the Dim/DimList objects for that site.
_cache = {}


def dims(lists=0):
    """Create Dim objects named after the variables they are assigned to.

    Inspects the caller's bytecode to discover the names on the left-hand side
    of the assignment (e.g. ``b, c = dims()`` yields dims named "b" and "c").
    The trailing `lists` targets become DimList objects instead of Dims.

    NOTE(review): the ``lasti // 2 + 1`` arithmetic assumes one 2-byte code
    unit per instruction with no inline CACHE entries; this is
    CPython-version-sensitive (the instruction stream changed in 3.11+) —
    confirm against the targeted runtime.
    """
    frame = inspect.currentframe()
    assert frame is not None
    calling_frame = frame.f_back
    assert calling_frame is not None
    code, lasti = calling_frame.f_code, calling_frame.f_lasti
    key = (code, lasti)
    if key not in _cache:
        # Index of the instruction immediately following the `dims(...)` call.
        first = lasti // 2 + 1
        instructions = list(dis.get_instructions(calling_frame.f_code))
        unpack = instructions[first]

        if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME":
            # just a single dim, not a list
            name = unpack.argval
            ctor = Dim if lists == 0 else DimList
            _cache[key] = lambda: ctor(name=name)
        else:
            assert unpack.opname == "UNPACK_SEQUENCE"
            ndims = unpack.argval
            # One STORE_* instruction per unpacked target follows the unpack.
            names = tuple(
                extract_name(instructions[first + 1 + i]) for i in range(ndims)
            )
            first_list = len(names) - lists
            _cache[key] = lambda: tuple(
                Dim(n) if i < first_list else DimList(name=n)
                for i, n in enumerate(names)
            )
    return _cache[key]()
106
+
107
+
108
+ def _dim_set(positional, arg):
109
+ def convert(a):
110
+ if isinstance(a, Dim):
111
+ return a
112
+ else:
113
+ assert isinstance(a, int)
114
+ return positional[a]
115
+
116
+ if arg is None:
117
+ return positional
118
+ elif not isinstance(arg, (Dim, int)):
119
+ return tuple(convert(a) for a in arg)
120
+ else:
121
+ return (convert(arg),)