Add files using upload-large-folder tool
Browse files- pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +712 -0
- pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_edm_euler.py +448 -0
- pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +482 -0
- pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_euler_discrete.py +757 -0
- pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_euler_discrete_flax.py +265 -0
- pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +561 -0
- pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
- pythonProject/.venv/Lib/site-packages/fsspec/__init__.py +71 -0
- pythonProject/.venv/Lib/site-packages/fsspec/_version.py +34 -0
- pythonProject/.venv/Lib/site-packages/fsspec/implementations/arrow.py +304 -0
- pythonProject/.venv/Lib/site-packages/fsspec/implementations/asyn_wrapper.py +122 -0
- pythonProject/.venv/Lib/site-packages/fsspec/implementations/cache_mapper.py +75 -0
- pythonProject/.venv/Lib/site-packages/fsspec/implementations/cache_metadata.py +233 -0
- pythonProject/.venv/Lib/site-packages/fsspec/implementations/cached.py +998 -0
- pythonProject/.venv/Lib/site-packages/fsspec/implementations/dask.py +152 -0
- pythonProject/.venv/Lib/site-packages/fsspec/implementations/data.py +58 -0
- pythonProject/.venv/Lib/site-packages/fsspec/implementations/dbfs.py +496 -0
- pythonProject/.venv/Lib/site-packages/fsspec/implementations/dirfs.py +388 -0
- pythonProject/.venv/Lib/site-packages/fsspec/utils.py +737 -0
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_dpm_cogvideox.py
ADDED
|
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
| 17 |
+
# and https://github.com/hojonathanho/diffusion
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
| 27 |
+
from ..utils import BaseOutput
|
| 28 |
+
from ..utils.torch_utils import randn_tensor
|
| 29 |
+
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class DDIMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # The less-noisy sample x_{t-1}, fed back into the model on the next step.
    prev_sample: torch.Tensor
    # Optional estimate of the fully denoised sample x_0 at the current step.
    pred_original_sample: Optional[torch.Tensor] = None
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
| 52 |
+
def betas_for_alpha_bar(
|
| 53 |
+
num_diffusion_timesteps,
|
| 54 |
+
max_beta=0.999,
|
| 55 |
+
alpha_transform_type="cosine",
|
| 56 |
+
):
|
| 57 |
+
"""
|
| 58 |
+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
| 59 |
+
(1-beta) over time from t = [0,1].
|
| 60 |
+
|
| 61 |
+
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
| 62 |
+
to that part of the diffusion process.
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
num_diffusion_timesteps (`int`): the number of betas to produce.
|
| 67 |
+
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
| 68 |
+
prevent singularities.
|
| 69 |
+
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
| 70 |
+
Choose from `cosine` or `exp`
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
| 74 |
+
"""
|
| 75 |
+
if alpha_transform_type == "cosine":
|
| 76 |
+
|
| 77 |
+
def alpha_bar_fn(t):
|
| 78 |
+
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
| 79 |
+
|
| 80 |
+
elif alpha_transform_type == "exp":
|
| 81 |
+
|
| 82 |
+
def alpha_bar_fn(t):
|
| 83 |
+
return math.exp(t * -12.0)
|
| 84 |
+
|
| 85 |
+
else:
|
| 86 |
+
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
|
| 87 |
+
|
| 88 |
+
betas = []
|
| 89 |
+
for i in range(num_diffusion_timesteps):
|
| 90 |
+
t1 = i / num_diffusion_timesteps
|
| 91 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
| 92 |
+
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
| 93 |
+
return torch.tensor(betas, dtype=torch.float32)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def rescale_zero_terminal_snr(alphas_cumprod):
    """
    Rescales the cumulative alpha products (`alphas_cumprod`) to have zero terminal SNR, based on
    https://huggingface.co/papers/2305.08891 (Algorithm 1). Unlike the reference implementation, this variant operates
    directly on `alphas_cumprod` rather than on betas.

    The rescaling shifts `sqrt(alphas_cumprod)` so the last timestep becomes exactly zero (zero SNR), then scales it so
    the first timestep keeps its original value.

    Args:
        alphas_cumprod (`torch.Tensor`):
            The cumulative products of `(1 - beta)` that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled `alphas_cumprod` with zero terminal SNR
    """

    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Square to revert the sqrt and recover cumulative alpha products.
    alphas_bar = alphas_bar_sqrt**2

    return alphas_bar
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
    """
    `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
    non-Markovian guidance.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        clip_sample (`bool`, defaults to `True`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, defaults to `True`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alpha value at step 0.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.0120,
        beta_schedule: str = "scaled_linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        rescale_betas_zero_snr: bool = False,
        snr_shift_scale: float = 3.0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            # NOTE: computed in float64 for precision before the cumprod below.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Modify: SNR shift following SD3.  With snr_shift_scale > 1 this lowers
        # alphas_cumprod (i.e. raises the noise level) at every timestep.
        self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod)

        # Rescale for zero SNR (Algorithm 1 of https://huggingface.co/papers/2305.08891)
        if rescale_betas_zero_snr:
            self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values (populated by `set_timesteps`); default timesteps cover
        # the full training range in descending order.
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

    def _get_variance(self, timestep, prev_timestep):
        # DDIM posterior variance between `timestep` and `prev_timestep`
        # (eq. 16 of https://huggingface.co/papers/2010.02502 with eta folded out).
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        # This scheduler requires no input scaling; the sample is returned unchanged.
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.

        Raises:
            ValueError: If `num_inference_steps` exceeds `num_train_timesteps`, or the configured
                `timestep_spacing` is not one of `linspace`, `leading`, `trailing`.
        """

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                .round()[::-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
            )

        self.timesteps = torch.from_numpy(timesteps).to(device)

    def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None):
        # lamb = log(sqrt(alpha_bar / (1 - alpha_bar))) is half the log-SNR at the
        # current timestep; h is the log-SNR step size to the next timestep.
        lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log()
        lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log()
        h = lamb_next - lamb

        if alpha_prod_t_back is not None:
            # r is the ratio of the previous step size to the current one, used
            # by the second-order (multistep) correction in `step`.
            lamb_previous = ((alpha_prod_t_back / (1 - alpha_prod_t_back)) ** 0.5).log()
            h_last = lamb - lamb_previous
            r = h_last / h
            return h, r, lamb, lamb_next
        else:
            return h, None, lamb, lamb_next

    def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back):
        # mult1/mult2 are the first-order (DPM-Solver) coefficients for the
        # sample and the denoised estimate; mult3/mult4 are the second-order
        # multistep weights (only returned when a previous step exists).
        mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp()
        mult2 = (-2 * h).expm1() * alpha_prod_t_prev**0.5

        if alpha_prod_t_back is not None:
            mult3 = 1 + 1 / (2 * r)
            mult4 = 1 / (2 * r)
            return mult1, mult2, mult3, mult4
        else:
            return mult1, mult2

    def step(
        self,
        model_output: torch.Tensor,
        old_pred_original_sample: torch.Tensor,
        timestep: int,
        timestep_back: int,
        sample: torch.Tensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.Tensor] = None,
        return_dict: bool = False,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            old_pred_original_sample (`torch.Tensor`):
                The `pred_original_sample` returned by the previous call to `step`, or `None` on the first step;
                enables the second-order multistep update.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            timestep_back (`int`):
                The timestep of the step before `timestep` (i.e. further back in the denoising trajectory), or
                `None` on the first step.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            eta (`float`):
                The weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`, defaults to `False`):
                If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
                because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
                clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
                `use_clipped_model_output` has no effect.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.Tensor`):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`CycleDiffusion`].
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `set_timesteps` has not been called, or `prediction_type` is unsupported.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://huggingface.co/papers/2010.02502
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        alpha_prod_t_back = self.alphas_cumprod[timestep_back] if timestep_back is not None else None

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502
        # To make style tests pass, commented out `pred_epsilon` as it is an unused variable
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            # pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. DPM-Solver update: compute log-SNR step size and solver coefficients.
        h, r, lamb, lamb_next = self.get_variables(alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back)
        mult = list(self.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back))
        mult_noise = (1 - alpha_prod_t_prev) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5

        # First-order (SDE-DPM-Solver) candidate for x_{t-1}.
        noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
        prev_sample = mult[0] * sample - mult[1] * pred_original_sample + mult_noise * noise

        if old_pred_original_sample is None or prev_timestep < 0:
            # Save a network evaluation if all noise levels are 0 or on the first step
            # NOTE: this branch returns a plain tuple regardless of `return_dict`.
            return prev_sample, pred_original_sample
        else:
            # Second-order multistep correction using the previous x_0 estimate.
            denoised_d = mult[2] * pred_original_sample - mult[3] * old_pred_original_sample
            noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
            x_advanced = mult[0] * sample - mult[1] * denoised_d + mult_noise * noise

            prev_sample = x_advanced

        if not return_dict:
            return (prev_sample, pred_original_sample)

        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """Diffuse `original_samples` to the noise level of `timesteps` using the forward process."""
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
        # for the subsequent add_noise calls
        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # Unsqueeze trailing dims so the per-timestep scalars broadcast over the sample shape.
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
        """Compute the v-prediction target `v = sqrt(alpha_bar) * noise - sqrt(1 - alpha_bar) * sample`."""
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # Unsqueeze trailing dims so the per-timestep scalars broadcast over the sample shape.
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self) -> int:
        # Length of the scheduler is the number of training timesteps.
        return self.config.num_train_timesteps
|
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
ADDED
|
@@ -0,0 +1,712 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
| 24 |
+
from ..utils.torch_utils import randn_tensor
|
| 25 |
+
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
| 29 |
+
"""
|
| 30 |
+
Implements DPMSolverMultistepScheduler in EDM formulation as presented in Karras et al. 2022 [1].
|
| 31 |
+
`EDMDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
|
| 32 |
+
|
| 33 |
+
[1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
|
| 34 |
+
https://huggingface.co/papers/2206.00364
|
| 35 |
+
|
| 36 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
| 37 |
+
methods the library implements for all schedulers such as loading and saving.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
sigma_min (`float`, *optional*, defaults to 0.002):
|
| 41 |
+
Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the EDM paper [1]; a reasonable
|
| 42 |
+
range is [0, 10].
|
| 43 |
+
sigma_max (`float`, *optional*, defaults to 80.0):
|
| 44 |
+
Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable
|
| 45 |
+
range is [0.2, 80.0].
|
| 46 |
+
sigma_data (`float`, *optional*, defaults to 0.5):
|
| 47 |
+
The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
|
| 48 |
+
sigma_schedule (`str`, *optional*, defaults to `karras`):
|
| 49 |
+
Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
|
| 50 |
+
(https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential
|
| 51 |
+
schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
|
| 52 |
+
num_train_timesteps (`int`, defaults to 1000):
|
| 53 |
+
The number of diffusion steps to train the model.
|
| 54 |
+
solver_order (`int`, defaults to 2):
|
| 55 |
+
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
|
| 56 |
+
sampling, and `solver_order=3` for unconditional sampling.
|
| 57 |
+
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
| 58 |
+
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
| 59 |
+
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
| 60 |
+
Video](https://imagen.research.google/video/paper.pdf) paper).
|
| 61 |
+
thresholding (`bool`, defaults to `False`):
|
| 62 |
+
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
| 63 |
+
as Stable Diffusion.
|
| 64 |
+
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
| 65 |
+
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
| 66 |
+
sample_max_value (`float`, defaults to 1.0):
|
| 67 |
+
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
|
| 68 |
+
`algorithm_type="dpmsolver++"`.
|
| 69 |
+
algorithm_type (`str`, defaults to `dpmsolver++`):
|
| 70 |
+
Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver++` type implements
|
| 71 |
+
the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to
|
| 72 |
+
use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
|
| 73 |
+
solver_type (`str`, defaults to `midpoint`):
|
| 74 |
+
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
|
| 75 |
+
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
|
| 76 |
+
lower_order_final (`bool`, defaults to `True`):
|
| 77 |
+
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
| 78 |
+
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
| 79 |
+
euler_at_final (`bool`, defaults to `False`):
|
| 80 |
+
Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
|
| 81 |
+
richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
|
| 82 |
+
steps, but sometimes may result in blurring.
|
| 83 |
+
final_sigmas_type (`str`, defaults to `"zero"`):
|
| 84 |
+
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
| 85 |
+
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
_compatibles = []
|
| 89 |
+
order = 1
|
| 90 |
+
|
| 91 |
+
@register_to_config
def __init__(
    self,
    sigma_min: float = 0.002,
    sigma_max: float = 80.0,
    sigma_data: float = 0.5,
    sigma_schedule: str = "karras",
    num_train_timesteps: int = 1000,
    prediction_type: str = "epsilon",
    rho: float = 7.0,
    solver_order: int = 2,
    thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    algorithm_type: str = "dpmsolver++",
    solver_type: str = "midpoint",
    lower_order_final: bool = True,
    euler_at_final: bool = False,
    final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
):
    """Validate the solver configuration and precompute the training sigma/timestep schedule.

    Raises:
        NotImplementedError: For unsupported `algorithm_type` / `solver_type` values.
        ValueError: For an incompatible `final_sigmas_type` or an unknown `sigma_schedule`.
    """
    # settings for DPM-Solver
    if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"]:
        if algorithm_type == "deis":
            # Legacy alias accepted for backwards compatibility; persisted as dpmsolver++.
            self.register_to_config(algorithm_type="dpmsolver++")
        else:
            raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")

    if solver_type not in ["midpoint", "heun"]:
        if solver_type in ["logrho", "bh1", "bh2"]:
            # Legacy solver names are mapped onto the midpoint rule.
            self.register_to_config(solver_type="midpoint")
        else:
            raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

    if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
        raise ValueError(
            f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
        )

    ramp = torch.linspace(0, 1, num_train_timesteps)
    if sigma_schedule == "karras":
        sigmas = self._compute_karras_sigmas(ramp)
    elif sigma_schedule == "exponential":
        sigmas = self._compute_exponential_sigmas(ramp)
    else:
        # Fix: an unknown schedule previously fell through and raised a confusing
        # NameError on `sigmas` below; fail loudly with the actual problem instead.
        raise ValueError(
            f"`sigma_schedule` must be one of 'karras' or 'exponential', but got {sigma_schedule}"
        )

    self.timesteps = self.precondition_noise(sigmas)

    # Append a terminal zero sigma so `sigmas[step_index + 1]` is valid on the last step.
    self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

    # setable values
    self.num_inference_steps = None
    self.model_outputs = [None] * solver_order
    self.lower_order_nums = 0
    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
|
| 146 |
+
|
| 147 |
+
@property
def init_noise_sigma(self):
    """Standard deviation of the initial noise distribution: `sqrt(sigma_max**2 + 1)`."""
    sigma_max = self.config.sigma_max
    return (sigma_max**2 + 1) ** 0.5
|
| 151 |
+
|
| 152 |
+
@property
def step_index(self):
    """Current position in the sampling schedule; incremented by one after every
    scheduler step (`None` until the first step)."""
    return self._step_index
|
| 158 |
+
|
| 159 |
+
@property
def begin_index(self):
    """Index of the first timestep to denoise from; set by the pipeline via
    `set_begin_index` (`None` if never set)."""
    return self._begin_index
|
| 165 |
+
|
| 166 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
    """
    Record the index of the first timestep to be used for denoising.

    Intended to be called by the pipeline before inference starts (e.g. for
    image-to-image, where denoising begins part-way through the schedule).

    Args:
        begin_index (`int`):
            The begin index for the scheduler.
    """
    self._begin_index = begin_index
|
| 176 |
+
|
| 177 |
+
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
def precondition_inputs(self, sample, sigma):
    """Scale the noisy sample by the EDM input-conditioning factor `c_in(sigma)`
    before it is fed to the denoising network."""
    conditioning = self._get_conditioning_c_in(sigma)
    return sample * conditioning
|
| 182 |
+
|
| 183 |
+
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_noise
def precondition_noise(self, sigma):
    """Map a sigma value to the network's noise-conditioning input,
    `c_noise = log(sigma) / 4` (EDM parameterization)."""
    if not isinstance(sigma, torch.Tensor):
        # Accept plain Python numbers by promoting them to a 1-element tensor.
        sigma = torch.tensor([sigma])
    return torch.log(sigma) * 0.25
|
| 191 |
+
|
| 192 |
+
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
def precondition_outputs(self, sample, model_output, sigma):
    """Combine the raw network output with the noisy sample via the EDM
    skip/output scalings to obtain a denoised (x0) estimate:

        denoised = c_skip(sigma) * sample + c_out(sigma) * model_output
    """
    sigma_data = self.config.sigma_data
    var_sum = sigma**2 + sigma_data**2
    c_skip = sigma_data**2 / var_sum

    prediction_type = self.config.prediction_type
    if prediction_type == "epsilon":
        c_out = sigma * sigma_data / var_sum**0.5
    elif prediction_type == "v_prediction":
        # v-prediction flips the sign of the output scaling.
        c_out = -sigma * sigma_data / var_sum**0.5
    else:
        raise ValueError(f"Prediction type {prediction_type} is not supported.")

    return c_skip * sample + c_out * model_output
|
| 207 |
+
|
| 208 |
+
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
    """
    Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
    current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

    Args:
        sample (`torch.Tensor`):
            The input sample.
        timestep (`int`, *optional*):
            The current timestep in the diffusion chain.

    Returns:
        `torch.Tensor`:
            A scaled input sample.
    """
    # Lazily resolve the step index from the timestep on the first call.
    if self.step_index is None:
        self._init_step_index(timestep)

    # Apply the EDM input preconditioning c_in(sigma) for the current step's sigma.
    sigma = self.sigmas[self.step_index]
    sample = self.precondition_inputs(sample, sigma)

    # Flag recording that input scaling happened — presumably checked by pipelines; verify consumers.
    self.is_scale_input_called = True
    return sample
|
| 232 |
+
|
| 233 |
+
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.

    Raises:
        ValueError: If the configured `sigma_schedule` or `final_sigmas_type` is unrecognized.
    """

    self.num_inference_steps = num_inference_steps

    ramp = torch.linspace(0, 1, self.num_inference_steps)
    if self.config.sigma_schedule == "karras":
        sigmas = self._compute_karras_sigmas(ramp)
    elif self.config.sigma_schedule == "exponential":
        sigmas = self._compute_exponential_sigmas(ramp)
    else:
        # Fix: an unrecognized schedule previously fell through and raised a
        # confusing NameError on `sigmas` below; fail loudly with the real problem.
        raise ValueError(
            f"`sigma_schedule` must be one of 'karras' or 'exponential', but got {self.config.sigma_schedule}"
        )

    sigmas = sigmas.to(dtype=torch.float32, device=device)
    self.timesteps = self.precondition_noise(sigmas)

    if self.config.final_sigmas_type == "sigma_min":
        sigma_last = self.config.sigma_min
    elif self.config.final_sigmas_type == "zero":
        sigma_last = 0
    else:
        raise ValueError(
            f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
        )

    # Append the terminal sigma so `sigmas[step_index + 1]` is valid on the last step.
    self.sigmas = torch.cat([sigmas, torch.tensor([sigma_last], dtype=torch.float32, device=device)])

    # Reset the multistep history and order bookkeeping for a fresh run.
    self.model_outputs = [
        None,
    ] * self.config.solver_order
    self.lower_order_nums = 0

    # add an index counter for schedulers that allow duplicated timesteps
    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
|
| 275 |
+
|
| 276 |
+
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
    """Construct the Karras et al. (2022) noise schedule: `ramp` in [0, 1] maps
    sigma_max down to sigma_min, interpolated linearly in sigma**(1/rho) space."""
    # NOTE: `or` fallback matches the original — a falsy override (0) also falls back to config.
    sigma_min = sigma_min or self.config.sigma_min
    sigma_max = sigma_max or self.config.sigma_max

    rho = self.config.rho
    inv_rho = 1 / rho
    lo, hi = sigma_min**inv_rho, sigma_max**inv_rho
    # Linear interpolation in warped space, then map back by raising to rho.
    return (hi + ramp * (lo - hi)) ** rho
|
| 287 |
+
|
| 288 |
+
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
    """Exponential sigma schedule (follows k-diffusion): log-linear spacing,
    returned in descending order from sigma_max to sigma_min.

    https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
    """
    sigma_min = sigma_min or self.config.sigma_min
    sigma_max = sigma_max or self.config.sigma_max
    # Evenly spaced in log-sigma, exponentiated, then flipped so sigmas descend.
    log_sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp))
    return log_sigmas.exp().flip(0)
|
| 298 |
+
|
| 299 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
    """
    "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
    prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
    s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
    pixels from saturation at each step. We find that dynamic thresholding results in significantly better
    photorealism as well as better image-text alignment, especially when using very large guidance weights."

    https://huggingface.co/papers/2205.11487
    """
    # Remember the caller's dtype so we can restore it after the float upcast below.
    dtype = sample.dtype
    batch_size, channels, *remaining_dims = sample.shape

    if dtype not in (torch.float32, torch.float64):
        sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

    # Flatten sample for doing quantile calculation along each image
    sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

    abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

    # Per-image percentile `s` of the absolute values (one value per batch entry).
    s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
    s = torch.clamp(
        s, min=1, max=self.config.sample_max_value
    )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
    s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
    sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

    # Restore the original shape and dtype before returning.
    sample = sample.reshape(batch_size, channels, *remaining_dims)
    sample = sample.to(dtype)

    return sample
|
| 332 |
+
|
| 333 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
    """Invert the sigma schedule: return the (fractional) timestep whose sigma
    matches `sigma`, via piecewise-linear interpolation in log-sigma space
    against the ascending `log_sigmas` table."""
    # Guard against log(0) with a small floor.
    log_sigma = np.log(np.maximum(sigma, 1e-10))

    # Signed distance of the query to every table entry.
    dists = log_sigma - log_sigmas[:, np.newaxis]

    # Bracket the query: index of the last non-negative distance, capped so
    # `high_idx` stays within the table.
    low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1

    low, high = log_sigmas[low_idx], log_sigmas[high_idx]

    # Interpolation weight inside the bracket, clamped to [0, 1].
    w = np.clip((low - log_sigma) / (low - high), 0, 1)

    # Fractional timestep, reshaped to match the query's shape.
    return ((1 - w) * low_idx + w * high_idx).reshape(sigma.shape)
|
| 356 |
+
|
| 357 |
+
def _sigma_to_alpha_sigma_t(self, sigma):
    """Return `(alpha_t, sigma_t)` for the given sigma. In the EDM
    parameterization inputs are pre-scaled before entering the unet, so
    alpha_t is identically 1 and sigma passes through unchanged."""
    return torch.tensor(1), sigma
|
| 362 |
+
|
| 363 |
+
def convert_model_output(
    self,
    model_output: torch.Tensor,
    sample: torch.Tensor = None,
) -> torch.Tensor:
    """
    Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
    designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
    integral of the data prediction model.

    <Tip>

    The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
    prediction and data prediction models.

    </Tip>

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.

    Returns:
        `torch.Tensor`:
            The converted model output.
    """
    # Apply the EDM output preconditioning at the current step's sigma to turn
    # the raw network output into a data (x0) prediction.
    sigma = self.sigmas[self.step_index]
    x0_pred = self.precondition_outputs(sample, model_output, sigma)

    # Optional dynamic thresholding of the x0 prediction (pixel-space models only).
    if self.config.thresholding:
        x0_pred = self._threshold_sample(x0_pred)

    return x0_pred
|
| 397 |
+
|
| 398 |
+
def dpm_solver_first_order_update(
    self,
    model_output: torch.Tensor,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    One step for the first-order DPMSolver (equivalent to DDIM).

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model. Here this is the converted
            data (x0) prediction produced by `convert_model_output`.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        noise (`torch.Tensor`, *optional*):
            Gaussian noise, required only for `algorithm_type="sde-dpmsolver++"`.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.
    """
    # Target sigma (next step) and current sigma.
    sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
    # alpha_t is 1 in this EDM parameterization (see `_sigma_to_alpha_sigma_t`),
    # so lambda reduces to -log(sigma).
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s = torch.log(alpha_s) - torch.log(sigma_s)

    # h is the step size in half-log-SNR space.
    h = lambda_t - lambda_s
    if self.config.algorithm_type == "dpmsolver++":
        x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
    elif self.config.algorithm_type == "sde-dpmsolver++":
        # SDE variant injects fresh noise scaled to keep the marginal variance correct.
        assert noise is not None
        x_t = (
            (sigma_t / sigma_s * torch.exp(-h)) * sample
            + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
            + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
        )

    return x_t
|
| 435 |
+
|
| 436 |
+
def multistep_dpm_solver_second_order_update(
    self,
    model_output_list: List[torch.Tensor],
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    One step for the second-order multistep DPMSolver.

    Args:
        model_output_list (`List[torch.Tensor]`):
            The direct outputs from learned diffusion model at current and latter timesteps.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        noise (`torch.Tensor`, *optional*):
            Gaussian noise, required only for `algorithm_type="sde-dpmsolver++"`.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.
    """
    # Sigmas at the target step (t), the current step (s0) and the previous step (s1).
    sigma_t, sigma_s0, sigma_s1 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
        self.sigmas[self.step_index - 1],
    )

    # alpha is identically 1 here (see `_sigma_to_alpha_sigma_t`), so each
    # lambda = log(alpha) - log(sigma) reduces to -log(sigma).
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)

    # m0 is the most recent converted model output, m1 the one before it.
    m0, m1 = model_output_list[-1], model_output_list[-2]

    # h: current step size in lambda space; h_0: previous step size; r0 their ratio.
    h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
    r0 = h_0 / h
    # D0: zeroth-order term; D1: scaled first difference of the model outputs.
    D0, D1 = m0, (1.0 / r0) * (m0 - m1)
    if self.config.algorithm_type == "dpmsolver++":
        # See https://huggingface.co/papers/2211.01095 for detailed derivations
        if self.config.solver_type == "midpoint":
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
            )
        elif self.config.solver_type == "heun":
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
            )
    elif self.config.algorithm_type == "sde-dpmsolver++":
        # SDE variant adds scaled fresh noise to preserve the marginal variance.
        assert noise is not None
        if self.config.solver_type == "midpoint":
            x_t = (
                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
        elif self.config.solver_type == "heun":
            x_t = (
                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )

    return x_t
|
| 506 |
+
|
| 507 |
+
def multistep_dpm_solver_third_order_update(
    self,
    model_output_list: List[torch.Tensor],
    sample: torch.Tensor = None,
) -> torch.Tensor:
    """
    One step for the third-order multistep DPMSolver.

    Args:
        model_output_list (`List[torch.Tensor]`):
            The direct outputs from learned diffusion model at current and latter timesteps.
        sample (`torch.Tensor`):
            A current instance of a sample created by diffusion process.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.
    """
    # Sigmas at the target step (t), the current step (s0) and the two previous steps (s1, s2).
    sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
        self.sigmas[self.step_index - 1],
        self.sigmas[self.step_index - 2],
    )

    # alpha is identically 1 in this EDM parameterization, so lambda = -log(sigma).
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
    alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
    lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)

    # m0 is the newest converted model output, m2 the oldest of the three.
    m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

    # Step sizes in lambda space and their ratios relative to the current step h.
    h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
    r0, r1 = h_0 / h, h_1 / h
    # D0/D1/D2: zeroth-, first- and second-order finite-difference terms.
    D0 = m0
    D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
    D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
    D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
    if self.config.algorithm_type == "dpmsolver++":
        # See https://huggingface.co/papers/2206.00927 for detailed derivations
        x_t = (
            (sigma_t / sigma_s0) * sample
            - (alpha_t * (torch.exp(-h) - 1.0)) * D0
            + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
            - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
        )

    return x_t
|
| 560 |
+
|
| 561 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
| 562 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """Map a timestep value to its position in the sigma/timestep schedule.

    Falls back to ``self.timesteps`` when no explicit schedule is given. If the
    timestep is absent from the schedule, the last index is returned; if it
    appears more than once, the second occurrence is preferred so that a run
    starting mid-schedule (e.g. image-to-image) does not skip a sigma.
    """
    timesteps = self.timesteps if schedule_timesteps is None else schedule_timesteps

    matches = (timesteps == timestep).nonzero()

    if len(matches) == 0:
        # timestep not in the schedule: clamp to the final position
        return len(self.timesteps) - 1

    # The sigma index taken for the **very** first `step` is always the second
    # match (or the only one when there is a single match); see note above.
    pick = 1 if len(matches) > 1 else 0
    return matches[pick].item()
|
| 580 |
+
|
| 581 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
| 582 |
+
def _init_step_index(self, timestep):
    """
    Initialize the step_index counter for the scheduler.
    """
    if self.begin_index is not None:
        # A pipeline pinned the starting position explicitly via set_begin_index.
        self._step_index = self._begin_index
        return

    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)
    self._step_index = self.index_for_timestep(timestep)
|
| 593 |
+
|
| 594 |
+
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[int, torch.Tensor],
    sample: torch.Tensor,
    generator=None,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
    the multistep DPMSolver.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        generator (`torch.Generator`, *optional*):
            A random number generator.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # Lazily resolve which schedule position this timestep corresponds to.
    if self.step_index is None:
        self._init_step_index(timestep)

    # Improve numerical stability for small number of steps: force lower-order
    # updates near the end of the schedule (and on short schedules).
    lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
        self.config.euler_at_final
        or (self.config.lower_order_final and len(self.timesteps) < 15)
        or self.config.final_sigmas_type == "zero"
    )
    lower_order_second = (
        (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
    )

    model_output = self.convert_model_output(model_output, sample=sample)
    # Shift the history buffer left by one and append the newest model output;
    # the multistep updates below read this buffer.
    for i in range(self.config.solver_order - 1):
        self.model_outputs[i] = self.model_outputs[i + 1]
    self.model_outputs[-1] = model_output

    # Only the stochastic variant injects fresh noise at each step.
    if self.config.algorithm_type == "sde-dpmsolver++":
        noise = randn_tensor(
            model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
        )
    else:
        noise = None

    # Pick the update order: warm-up steps (lower_order_nums) and the
    # end-of-schedule flags above can downgrade to a lower-order update.
    if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
        prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
    elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
        prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
    else:
        prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)

    if self.lower_order_nums < self.config.solver_order:
        self.lower_order_nums += 1

    # upon completion increase step index by one
    self._step_index += 1

    if not return_dict:
        return (prev_sample,)

    return SchedulerOutput(prev_sample=prev_sample)
|
| 671 |
+
|
| 672 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
| 673 |
+
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    """Forward-diffuse `original_samples` to the noise level of `timesteps`.

    Returns `original_samples + noise * sigma`, where sigma is looked up per
    timestep from the scheduler's sigma schedule.
    """
    # Make sure sigmas and timesteps have the same device and dtype as original_samples
    sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
    if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
        # mps does not support float64
        schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
        timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
    else:
        schedule_timesteps = self.timesteps.to(original_samples.device)
        timesteps = timesteps.to(original_samples.device)

    # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
    if self.begin_index is None:
        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
    elif self.step_index is not None:
        # add_noise is called after first denoising step (for inpainting)
        step_indices = [self.step_index] * timesteps.shape[0]
    else:
        # add noise is called before first denoising step to create initial latent(img2img)
        step_indices = [self.begin_index] * timesteps.shape[0]

    # Broadcast the per-sample sigma across the remaining sample dimensions.
    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < len(original_samples.shape):
        sigma = sigma.unsqueeze(-1)

    noisy_samples = original_samples + noise * sigma
    return noisy_samples
|
| 705 |
+
|
| 706 |
+
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
|
| 707 |
+
def _get_conditioning_c_in(self, sigma):
    """Return the EDM input-preconditioning coefficient c_in = 1 / sqrt(sigma**2 + sigma_data**2)."""
    sigma_data = self.config.sigma_data
    c_in = 1 / ((sigma**2 + sigma_data**2) ** 0.5)
    return c_in
|
| 710 |
+
|
| 711 |
+
def __len__(self):
    """The scheduler's length is the configured number of training timesteps."""
    config = self.config
    return config.num_train_timesteps
|
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_edm_euler.py
ADDED
|
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
| 22 |
+
from ..utils import BaseOutput, logging
|
| 23 |
+
from ..utils.torch_utils import randn_tensor
|
| 24 |
+
from .scheduling_utils import SchedulerMixin
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
class EDMEulerSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # next model input in the denoising loop
    prev_sample: torch.Tensor
    # optional x_0 preview; None when the scheduler does not compute it
    pred_original_sample: Optional[torch.Tensor] = None
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
    """
    Implements the Euler scheduler in EDM formulation as presented in Karras et al. 2022 [1].

    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
    https://huggingface.co/papers/2206.00364

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        sigma_min (`float`, *optional*, defaults to 0.002):
            Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the EDM paper [1]; a reasonable
            range is [0, 10].
        sigma_max (`float`, *optional*, defaults to 80.0):
            Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable
            range is [0.2, 80.0].
        sigma_data (`float`, *optional*, defaults to 0.5):
            The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
        sigma_schedule (`str`, *optional*, defaults to `karras`):
            Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
            (https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential
            schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        rho (`float`, *optional*, defaults to 7.0):
            The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
        final_sigmas_type (`str`, defaults to `"zero"`):
            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        sigma_min: float = 0.002,
        sigma_max: float = 80.0,
        sigma_data: float = 0.5,
        sigma_schedule: str = "karras",
        num_train_timesteps: int = 1000,
        prediction_type: str = "epsilon",
        rho: float = 7.0,
        final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
    ):
        if sigma_schedule not in ["karras", "exponential"]:
            raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")

        # setable values
        self.num_inference_steps = None

        # Build the training-time sigma schedule in float64 where possible
        # (mps does not support float64, hence the fallback to float32).
        sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
        sigmas = torch.arange(num_train_timesteps + 1, dtype=sigmas_dtype) / num_train_timesteps
        if sigma_schedule == "karras":
            sigmas = self._compute_karras_sigmas(sigmas)
        elif sigma_schedule == "exponential":
            sigmas = self._compute_exponential_sigmas(sigmas)
        sigmas = sigmas.to(torch.float32)

        # Timesteps are the noise-conditioning values derived from sigmas.
        self.timesteps = self.precondition_noise(sigmas)

        if self.config.final_sigmas_type == "sigma_min":
            sigma_last = sigmas[-1]
        elif self.config.final_sigmas_type == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
            )

        self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])

        self.is_scale_input_called = False

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        return (self.config.sigma_max**2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def precondition_inputs(self, sample, sigma):
        """Scale the noisy sample by c_in = 1 / sqrt(sigma**2 + sigma_data**2) (EDM input preconditioning)."""
        c_in = self._get_conditioning_c_in(sigma)
        scaled_sample = sample * c_in
        return scaled_sample

    def precondition_noise(self, sigma):
        """Return the noise-conditioning value c_noise = 0.25 * log(sigma) used as the model timestep."""
        if not isinstance(sigma, torch.Tensor):
            sigma = torch.tensor([sigma])

        c_noise = 0.25 * torch.log(sigma)

        return c_noise

    def precondition_outputs(self, sample, model_output, sigma):
        """Combine the raw model output with the sample into a denoised prediction (EDM output preconditioning).

        Computes `c_skip * sample + c_out * model_output`, where `c_out` flips sign for `v_prediction`.
        """
        sigma_data = self.config.sigma_data
        c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)

        if self.config.prediction_type == "epsilon":
            c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        elif self.config.prediction_type == "v_prediction":
            c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        else:
            raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")

        denoised = c_skip * sample + c_out * model_output

        return denoised

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = self.precondition_inputs(sample, sigma)

        # `step` warns if this flag was never set before it runs.
        self.is_scale_input_called = True
        return sample

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
                Custom sigmas to use for the denoising process. If not defined, the default behavior when
                `num_inference_steps` is passed will be used.
        """
        self.num_inference_steps = num_inference_steps

        # mps does not support float64; otherwise compute the ramp in float64.
        sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
        if sigmas is None:
            sigmas = torch.linspace(0, 1, self.num_inference_steps, dtype=sigmas_dtype)
        elif isinstance(sigmas, float):
            sigmas = torch.tensor(sigmas, dtype=sigmas_dtype)
        else:
            sigmas = sigmas.to(sigmas_dtype)
        if self.config.sigma_schedule == "karras":
            sigmas = self._compute_karras_sigmas(sigmas)
        elif self.config.sigma_schedule == "exponential":
            sigmas = self._compute_exponential_sigmas(sigmas)
        sigmas = sigmas.to(dtype=torch.float32, device=device)

        self.timesteps = self.precondition_noise(sigmas)

        if self.config.final_sigmas_type == "sigma_min":
            sigma_last = sigmas[-1]
        elif self.config.final_sigmas_type == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
            )

        self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
        # Reset the per-run position counters; they are re-resolved lazily in `step`.
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
    def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
        """Constructs the noise schedule of Karras et al. (2022)."""
        sigma_min = sigma_min or self.config.sigma_min
        sigma_max = sigma_max or self.config.sigma_max

        # Interpolate in sigma^(1/rho) space, then raise back to the rho power.
        rho = self.config.rho
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
        """Implementation closely follows k-diffusion.

        https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
        """
        sigma_min = sigma_min or self.config.sigma_min
        sigma_max = sigma_max or self.config.sigma_max
        # Evenly spaced in log-sigma, flipped so sigmas run high -> low.
        sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        """Initialize `_step_index` from `begin_index` if set, else by locating `timestep` in the schedule."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        pred_original_sample: Optional[torch.Tensor] = None,
    ) -> Union[EDMEulerSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
            s_tmin (`float`):
            s_tmax (`float`):
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EDMEulerScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]

        # Stochastic "churn": temporarily raise sigma when it lies in [s_tmin, s_tmax].
        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            noise = randn_tensor(
                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
            )
            eps = noise * s_noise
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if pred_original_sample is None:
            pred_original_sample = self.precondition_outputs(sample, model_output, sigma_hat)

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma_hat

        dt = self.sigmas[self.step_index + 1] - sigma_hat

        prev_sample = sample + derivative * dt

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (
                prev_sample,
                pred_original_sample,
            )

        return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def _get_conditioning_c_in(self, sigma):
        """Return the EDM input-preconditioning coefficient c_in = 1 / sqrt(sigma**2 + sigma_data**2)."""
        c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
        return c_in

    def __len__(self):
        """The scheduler's length is the configured number of training timesteps."""
        return self.config.num_train_timesteps
|
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
| 23 |
+
from ..utils import BaseOutput, logging
|
| 24 |
+
from ..utils.torch_utils import randn_tensor
|
| 25 |
+
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete
class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # Sample to feed back into the model at the next denoising iteration (x_{t-1}).
    prev_sample: torch.Tensor
    # Optional preview of the fully denoised sample (x_0); None when not computed.
    pred_original_sample: Optional[torch.Tensor] = None
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
| 51 |
+
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Build a discrete beta schedule from a continuous alpha-bar function.

    `alpha_bar(t)` defines the cumulative product of (1 - beta) over diffusion
    time t in [0, 1]; each beta_i is recovered from the ratio of alpha_bar at
    consecutive grid points and clipped at `max_beta` to prevent singularities.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), clipped at max_beta.
    steps = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar_fn((i + 1) / steps) / alpha_bar_fn(i / steps), max_beta)
        for i in range(steps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
| 96 |
+
def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)


    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # Work in sqrt(alpha_bar) space, where the rescaling is an affine map.
    bar_sqrt = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Remember the endpoints of the original schedule.
    first = bar_sqrt[0].clone()
    last = bar_sqrt[-1].clone()

    # Shift so sqrt(alpha_bar) reaches exactly zero at the final timestep,
    # then rescale so the first timestep keeps its original value.
    bar_sqrt = bar_sqrt - last
    bar_sqrt = bar_sqrt * (first / (first - last))

    # Undo the sqrt and the cumulative product to recover per-step alphas.
    alphas_bar = bar_sqrt**2
    stepwise_alphas = alphas_bar[1:] / alphas_bar[:-1]
    stepwise_alphas = torch.cat([alphas_bar[0:1], stepwise_alphas])

    return 1 - stepwise_alphas
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Ancestral sampling with Euler method steps.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    # Schedulers this one can be swapped with via `from_config`.
    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    # First-order solver: one model evaluation per step.
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr:
            # Close to 0 without being 0 so first sigma is not inf
            # FP16 smallest positive subnormal works well here
            self.alphas_cumprod[-1] = 2**-24

        # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), stored in descending
        # order with a terminal 0.0 appended.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        # Tracks whether callers invoked `scale_model_input`; `step` warns if not.
        self.is_scale_input_called = False

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        self.is_scale_input_called = True
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
                ::-1
            ].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # Interpolate training sigmas onto the (possibly fractional) inference
        # timesteps, then append the terminal 0.0 sigma.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)

        self.timesteps = torch.from_numpy(timesteps).to(device=device)
        # Reset the per-run counters so the next denoising loop starts fresh.
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        # Lazily initialize `_step_index` on the first call to `step` /
        # `scale_model_input`, either by looking up the timestep or from the
        # pipeline-provided begin index.
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`,
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.

        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # Ancestral split of the variance between the deterministic move
        # (down to `sigma_down`) and the fresh noise injection (`sigma_up`).
        sigma_from = self.sigmas[self.step_index]
        sigma_to = self.sigmas[self.step_index + 1]
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma

        dt = sigma_down - sigma

        prev_sample = sample + derivative * dt

        device = model_output.device
        noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)

        prev_sample = prev_sample + noise * sigma_up

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (
                prev_sample,
                pred_original_sample,
            )

        return EulerAncestralDiscreteSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample
        )

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        # Broadcast the per-sample sigma over the trailing sample dimensions.
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        # The scheduler's "length" is the size of the full training schedule.
        return self.config.num_train_timesteps
|
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_euler_discrete.py
ADDED
|
@@ -0,0 +1,757 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
| 23 |
+
from ..utils import BaseOutput, is_scipy_available, logging
|
| 24 |
+
from ..utils.torch_utils import randn_tensor
|
| 25 |
+
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
if is_scipy_available():
|
| 29 |
+
import scipy.stats
|
| 30 |
+
|
| 31 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
class EulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # Sample to feed to the model at the next denoising iteration (x_{t-1}).
    prev_sample: torch.Tensor
    # Estimate of the fully denoised sample (x_0); `None` when not computed.
    pred_original_sample: Optional[torch.Tensor] = None
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
| 54 |
+
def betas_for_alpha_bar(
|
| 55 |
+
num_diffusion_timesteps,
|
| 56 |
+
max_beta=0.999,
|
| 57 |
+
alpha_transform_type="cosine",
|
| 58 |
+
):
|
| 59 |
+
"""
|
| 60 |
+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
| 61 |
+
(1-beta) over time from t = [0,1].
|
| 62 |
+
|
| 63 |
+
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
| 64 |
+
to that part of the diffusion process.
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
num_diffusion_timesteps (`int`): the number of betas to produce.
|
| 69 |
+
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
| 70 |
+
prevent singularities.
|
| 71 |
+
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
| 72 |
+
Choose from `cosine` or `exp`
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
| 76 |
+
"""
|
| 77 |
+
if alpha_transform_type == "cosine":
|
| 78 |
+
|
| 79 |
+
def alpha_bar_fn(t):
|
| 80 |
+
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
| 81 |
+
|
| 82 |
+
elif alpha_transform_type == "exp":
|
| 83 |
+
|
| 84 |
+
def alpha_bar_fn(t):
|
| 85 |
+
return math.exp(t * -12.0)
|
| 86 |
+
|
| 87 |
+
else:
|
| 88 |
+
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
|
| 89 |
+
|
| 90 |
+
betas = []
|
| 91 |
+
for i in range(num_diffusion_timesteps):
|
| 92 |
+
t1 = i / num_diffusion_timesteps
|
| 93 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
| 94 |
+
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
| 95 |
+
return torch.tensor(betas, dtype=torch.float32)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)


    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # Work in sqrt(alpha_bar) space, where the rescaling is an affine map.
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Remember the original endpoints before shifting.
    first_value = sqrt_alpha_bar[0].clone()
    terminal_value = sqrt_alpha_bar[-1].clone()

    # Shift so the last timestep lands exactly at zero, then scale so the first
    # timestep is restored to its original value.
    sqrt_alpha_bar = (sqrt_alpha_bar - terminal_value) * (first_value / (first_value - terminal_value))

    # Undo the sqrt and the cumulative product to recover per-step alphas, then betas.
    alpha_bar = sqrt_alpha_bar**2
    alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - alphas
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
| 136 |
+
"""
|
| 137 |
+
Euler scheduler.
|
| 138 |
+
|
| 139 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
| 140 |
+
methods the library implements for all schedulers such as loading and saving.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
num_train_timesteps (`int`, defaults to 1000):
|
| 144 |
+
The number of diffusion steps to train the model.
|
| 145 |
+
beta_start (`float`, defaults to 0.0001):
|
| 146 |
+
The starting `beta` value of inference.
|
| 147 |
+
beta_end (`float`, defaults to 0.02):
|
| 148 |
+
The final `beta` value.
|
| 149 |
+
beta_schedule (`str`, defaults to `"linear"`):
|
| 150 |
+
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
| 151 |
+
`linear` or `scaled_linear`.
|
| 152 |
+
trained_betas (`np.ndarray`, *optional*):
|
| 153 |
+
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
| 154 |
+
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
| 155 |
+
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
| 156 |
+
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
| 157 |
+
Video](https://imagen.research.google/video/paper.pdf) paper).
|
| 158 |
+
interpolation_type(`str`, defaults to `"linear"`, *optional*):
|
| 159 |
+
The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be on of
|
| 160 |
+
`"linear"` or `"log_linear"`.
|
| 161 |
+
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
| 162 |
+
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
| 163 |
+
the sigmas are determined according to a sequence of noise levels {σi}.
|
| 164 |
+
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
| 165 |
+
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
| 166 |
+
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
| 167 |
+
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
| 168 |
+
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
| 169 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
| 170 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
| 171 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
| 172 |
+
steps_offset (`int`, defaults to 0):
|
| 173 |
+
An offset added to the inference steps, as required by some model families.
|
| 174 |
+
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
| 175 |
+
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
| 176 |
+
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
| 177 |
+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
| 178 |
+
final_sigmas_type (`str`, defaults to `"zero"`):
|
| 179 |
+
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
| 180 |
+
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
| 181 |
+
"""
|
| 182 |
+
|
| 183 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
| 184 |
+
order = 1
|
| 185 |
+
|
| 186 |
+
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        interpolation_type: str = "linear",
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
        timestep_spacing: str = "linspace",
        timestep_type: str = "discrete",  # can be "discrete" or "continuous"
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
        final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
    ):
        # NOTE: `register_to_config` stores every argument on `self.config` before the
        # body runs, which is why several values are read back via `self.config`.
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
        # The three alternative sigma schedules are mutually exclusive.
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )
        # Build the training beta schedule (explicit betas take precedence).
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr:
            # Close to 0 without being 0 so first sigma is not inf
            # FP16 smallest positive subnormal works well here
            self.alphas_cumprod[-1] = 2**-24

        # sigma(t) = sqrt((1 - alpha_bar) / alpha_bar); flipped so sigmas descend as
        # inference proceeds from high to low noise.
        sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0)
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        # setable values
        self.num_inference_steps = None

        # TODO: Support the full EDM scalings for all prediction types and timestep types
        if timestep_type == "continuous" and prediction_type == "v_prediction":
            self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
        else:
            self.timesteps = timesteps

        # Append a trailing zero sigma for the final denoising step.
        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

        self.is_scale_input_called = False
        self.use_karras_sigmas = use_karras_sigmas
        self.use_exponential_sigmas = use_exponential_sigmas
        self.use_beta_sigmas = use_beta_sigmas

        # Lazily-initialized per-run counters (see `_init_step_index` / `set_begin_index`).
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
|
| 260 |
+
|
| 261 |
+
@property
|
| 262 |
+
def init_noise_sigma(self):
|
| 263 |
+
# standard deviation of the initial noise distribution
|
| 264 |
+
max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
|
| 265 |
+
if self.config.timestep_spacing in ["linspace", "trailing"]:
|
| 266 |
+
return max_sigma
|
| 267 |
+
|
| 268 |
+
return (max_sigma**2 + 1) ** 0.5
|
| 269 |
+
|
| 270 |
+
    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        # `None` until the first `step`/`scale_model_input` call triggers `_init_step_index`.
        return self._step_index
|
| 276 |
+
|
| 277 |
+
    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        # `None` unless `set_begin_index` has been called since the last `set_timesteps`.
        return self._begin_index
|
| 283 |
+
|
| 284 |
+
    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index
|
| 294 |
+
|
| 295 |
+
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
| 296 |
+
"""
|
| 297 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
| 298 |
+
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
|
| 299 |
+
|
| 300 |
+
Args:
|
| 301 |
+
sample (`torch.Tensor`):
|
| 302 |
+
The input sample.
|
| 303 |
+
timestep (`int`, *optional*):
|
| 304 |
+
The current timestep in the diffusion chain.
|
| 305 |
+
|
| 306 |
+
Returns:
|
| 307 |
+
`torch.Tensor`:
|
| 308 |
+
A scaled input sample.
|
| 309 |
+
"""
|
| 310 |
+
if self.step_index is None:
|
| 311 |
+
self._init_step_index(timestep)
|
| 312 |
+
|
| 313 |
+
sigma = self.sigmas[self.step_index]
|
| 314 |
+
sample = sample / ((sigma**2 + 1) ** 0.5)
|
| 315 |
+
|
| 316 |
+
self.is_scale_input_called = True
|
| 317 |
+
return sample
|
| 318 |
+
|
| 319 |
+
    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        timesteps: Optional[List[int]] = None,
        sigmas: Optional[List[float]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
                based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
                must be `None`, and `timestep_spacing` attribute will be ignored.
            sigmas (`List[float]`, *optional*):
                Custom sigmas used to support arbitrary timesteps schedule schedule. If `None`, timesteps and sigmas
                will be generated based on the relevant scheduler attributes. If `sigmas` is passed,
                `num_inference_steps` and `timesteps` must be `None`, and the timesteps will be generated based on the
                custom sigmas schedule.
        """

        # Exactly one of `num_inference_steps`, `timesteps`, `sigmas` may drive the schedule.
        if timesteps is not None and sigmas is not None:
            raise ValueError("Only one of `timesteps` or `sigmas` should be set.")
        if num_inference_steps is None and timesteps is None and sigmas is None:
            raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps` or `sigmas.")
        if num_inference_steps is not None and (timesteps is not None or sigmas is not None):
            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.")
        # Custom timesteps are incompatible with the remapped sigma schedules.
        if timesteps is not None and self.config.use_karras_sigmas:
            raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
        if timesteps is not None and self.config.use_exponential_sigmas:
            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
        if timesteps is not None and self.config.use_beta_sigmas:
            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
        if (
            timesteps is not None
            and self.config.timestep_type == "continuous"
            and self.config.prediction_type == "v_prediction"
        ):
            raise ValueError(
                "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
            )

        # Derive the step count from whichever custom schedule was supplied.
        if num_inference_steps is None:
            num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1
        self.num_inference_steps = num_inference_steps

        if sigmas is not None:
            # Custom sigma schedule: derive matching (fractional) timesteps from it.
            log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
            sigmas = np.array(sigmas).astype(np.float32)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])

        else:
            if timesteps is not None:
                timesteps = np.array(timesteps).astype(np.float32)
            else:
                # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
                if self.config.timestep_spacing == "linspace":
                    timesteps = np.linspace(
                        0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
                    )[::-1].copy()
                elif self.config.timestep_spacing == "leading":
                    step_ratio = self.config.num_train_timesteps // self.num_inference_steps
                    # creates integer timesteps by multiplying by ratio
                    # casting to int to avoid issues when num_inference_step is power of 3
                    timesteps = (
                        (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
                    )
                    timesteps += self.config.steps_offset
                elif self.config.timestep_spacing == "trailing":
                    step_ratio = self.config.num_train_timesteps / self.num_inference_steps
                    # creates integer timesteps by multiplying by ratio
                    # casting to int to avoid issues when num_inference_step is power of 3
                    timesteps = (
                        (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
                    )
                    timesteps -= 1
                else:
                    raise ValueError(
                        f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
                    )

            # Interpolate the training sigmas at the chosen (possibly fractional) timesteps.
            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
            log_sigmas = np.log(sigmas)
            if self.config.interpolation_type == "linear":
                sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            elif self.config.interpolation_type == "log_linear":
                sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
            else:
                raise ValueError(
                    f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
                    " 'linear' or 'log_linear'"
                )

            # Optionally remap sigmas; options are mutually exclusive (validated in __init__).
            # Timesteps are then re-derived from the remapped sigmas.
            if self.config.use_karras_sigmas:
                sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
                timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

            elif self.config.use_exponential_sigmas:
                sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
                timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

            elif self.config.use_beta_sigmas:
                sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
                timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        # Choose the terminal sigma appended after the schedule.
        if self.config.final_sigmas_type == "sigma_min":
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
        elif self.config.final_sigmas_type == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
            )

        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

        # TODO: Support the full EDM scalings for all prediction types and timestep types
        if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
            self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device)
        else:
            self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)

        # Reset per-run state.
        self._step_index = None
        self._begin_index = None
        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
|
| 451 |
+
|
| 452 |
+
    def _sigma_to_t(self, sigma, log_sigmas):
        # Map sigma value(s) to fractional timestep indices by piecewise-linear
        # interpolation in log-sigma space against the training schedule `log_sigmas`.
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))  # clamp so a terminal sigma of 0 doesn't produce -inf

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        # cumsum+argmax picks the last schedule entry satisfying `dists >= 0`; the clip
        # keeps `high_idx` inside the array so the bracketing pair is always valid.
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)  # clamp for sigmas outside the bracketed range

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t
|
| 474 |
+
|
| 475 |
+
# Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
|
| 476 |
+
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
| 477 |
+
"""Constructs the noise schedule of Karras et al. (2022)."""
|
| 478 |
+
|
| 479 |
+
# Hack to make sure that other schedulers which copy this function don't break
|
| 480 |
+
# TODO: Add this logic to the other schedulers
|
| 481 |
+
if hasattr(self.config, "sigma_min"):
|
| 482 |
+
sigma_min = self.config.sigma_min
|
| 483 |
+
else:
|
| 484 |
+
sigma_min = None
|
| 485 |
+
|
| 486 |
+
if hasattr(self.config, "sigma_max"):
|
| 487 |
+
sigma_max = self.config.sigma_max
|
| 488 |
+
else:
|
| 489 |
+
sigma_max = None
|
| 490 |
+
|
| 491 |
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
| 492 |
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
| 493 |
+
|
| 494 |
+
rho = 7.0 # 7.0 is the value used in the paper
|
| 495 |
+
ramp = np.linspace(0, 1, num_inference_steps)
|
| 496 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
| 497 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
| 498 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
| 499 |
+
return sigmas
|
| 500 |
+
|
| 501 |
+
# Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L26
|
| 502 |
+
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
|
| 503 |
+
"""Constructs an exponential noise schedule."""
|
| 504 |
+
|
| 505 |
+
# Hack to make sure that other schedulers which copy this function don't break
|
| 506 |
+
# TODO: Add this logic to the other schedulers
|
| 507 |
+
if hasattr(self.config, "sigma_min"):
|
| 508 |
+
sigma_min = self.config.sigma_min
|
| 509 |
+
else:
|
| 510 |
+
sigma_min = None
|
| 511 |
+
|
| 512 |
+
if hasattr(self.config, "sigma_max"):
|
| 513 |
+
sigma_max = self.config.sigma_max
|
| 514 |
+
else:
|
| 515 |
+
sigma_max = None
|
| 516 |
+
|
| 517 |
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
| 518 |
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
| 519 |
+
|
| 520 |
+
sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
|
| 521 |
+
return sigmas
|
| 522 |
+
|
| 523 |
+
def _convert_to_beta(
|
| 524 |
+
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
| 525 |
+
) -> torch.Tensor:
|
| 526 |
+
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
| 527 |
+
|
| 528 |
+
# Hack to make sure that other schedulers which copy this function don't break
|
| 529 |
+
# TODO: Add this logic to the other schedulers
|
| 530 |
+
if hasattr(self.config, "sigma_min"):
|
| 531 |
+
sigma_min = self.config.sigma_min
|
| 532 |
+
else:
|
| 533 |
+
sigma_min = None
|
| 534 |
+
|
| 535 |
+
if hasattr(self.config, "sigma_max"):
|
| 536 |
+
sigma_max = self.config.sigma_max
|
| 537 |
+
else:
|
| 538 |
+
sigma_max = None
|
| 539 |
+
|
| 540 |
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
| 541 |
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
| 542 |
+
|
| 543 |
+
sigmas = np.array(
|
| 544 |
+
[
|
| 545 |
+
sigma_min + (ppf * (sigma_max - sigma_min))
|
| 546 |
+
for ppf in [
|
| 547 |
+
scipy.stats.beta.ppf(timestep, alpha, beta)
|
| 548 |
+
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
| 549 |
+
]
|
| 550 |
+
]
|
| 551 |
+
)
|
| 552 |
+
return sigmas
|
| 553 |
+
|
| 554 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
| 555 |
+
if schedule_timesteps is None:
|
| 556 |
+
schedule_timesteps = self.timesteps
|
| 557 |
+
|
| 558 |
+
indices = (schedule_timesteps == timestep).nonzero()
|
| 559 |
+
|
| 560 |
+
# The sigma index that is taken for the **very** first `step`
|
| 561 |
+
# is always the second index (or the last index if there is only 1)
|
| 562 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
| 563 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
| 564 |
+
pos = 1 if len(indices) > 1 else 0
|
| 565 |
+
|
| 566 |
+
return indices[pos].item()
|
| 567 |
+
|
| 568 |
+
    def _init_step_index(self, timestep):
        # Resolve the starting step index lazily: honor an explicit begin index set via
        # `set_begin_index`, otherwise look the given timestep up in the schedule.
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                # Match devices before the equality comparison in `index_for_timestep`.
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index
|
| 575 |
+
|
| 576 |
+
    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
            s_tmin  (`float`):
            s_tmax  (`float`):
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
                tuple.

        Returns:
            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]

        # Stochastic churn: gamma > 0 temporarily raises the noise level before the
        # Euler step (only when sigma lies within [s_tmin, s_tmax]); gamma == 0 gives
        # the plain deterministic step.
        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            noise = randn_tensor(
                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
            )
            eps = noise * s_noise
            # Add just enough noise to move the sample from level sigma to sigma_hat.
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        # NOTE: "original_sample" should not be an expected prediction_type but is left in for
        # backwards compatibility
        if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
            pred_original_sample = model_output
        elif self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma_hat * model_output
        elif self.config.prediction_type == "v_prediction":
            # denoised = model_output * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma_hat

        # Euler step from sigma_hat down to the next sigma in the schedule.
        dt = self.sigmas[self.step_index + 1] - sigma_hat

        prev_sample = sample + derivative * dt

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (
                prev_sample,
                pred_original_sample,
            )

        return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
| 685 |
+
|
| 686 |
+
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    """Noise `original_samples` to the level of `timesteps`: x_t = x_0 + sigma_t * eps."""
    target_device = original_samples.device
    # Align the sigma schedule with the sample's device and dtype.
    sigma_schedule = self.sigmas.to(device=target_device, dtype=original_samples.dtype)
    if target_device.type == "mps" and torch.is_floating_point(timesteps):
        # mps does not support float64
        schedule_timesteps = self.timesteps.to(target_device, dtype=torch.float32)
        timesteps = timesteps.to(target_device, dtype=torch.float32)
    else:
        schedule_timesteps = self.timesteps.to(target_device)
        timesteps = timesteps.to(target_device)

    # self.begin_index is None when the scheduler is used for training, or the
    # pipeline does not implement set_begin_index.
    if self.begin_index is None:
        indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
    elif self.step_index is not None:
        # called after the first denoising step (e.g. for inpainting)
        indices = [self.step_index] * timesteps.shape[0]
    else:
        # called before the first denoising step to create the initial latent (img2img)
        indices = [self.begin_index] * timesteps.shape[0]

    per_sample_sigma = sigma_schedule[indices].flatten()
    # Append trailing singleton dims so sigma broadcasts against the sample.
    while len(per_sample_sigma.shape) < len(original_samples.shape):
        per_sample_sigma = per_sample_sigma.unsqueeze(-1)

    return original_samples + noise * per_sample_sigma
| 719 |
+
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
    """Compute the v-prediction target: v = sqrt(alpha_bar) * noise - sqrt(1 - alpha_bar) * sample."""
    if (
        isinstance(timesteps, int)
        or isinstance(timesteps, torch.IntTensor)
        or isinstance(timesteps, torch.LongTensor)
    ):
        # Integer inputs would be positional indices, not schedule timesteps; reject them.
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.get_velocity()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            ),
        )

    if sample.device.type == "mps" and torch.is_floating_point(timesteps):
        # mps does not support float64
        schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
        timesteps = timesteps.to(sample.device, dtype=torch.float32)
    else:
        schedule_timesteps = self.timesteps.to(sample.device)
        timesteps = timesteps.to(sample.device)

    # Map each requested timestep to its position in the current schedule.
    step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
    alphas_cumprod = self.alphas_cumprod.to(sample)
    sqrt_alpha_prod = alphas_cumprod[step_indices] ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
    # Append trailing singleton dims so the coefficient broadcasts against the sample.
    while len(sqrt_alpha_prod.shape) < len(sample.shape):
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[step_indices]) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
    while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

    velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
    return velocity
|
| 756 |
+
def __len__(self):
    # The scheduler's length is the configured number of training timesteps.
    return self.config.num_train_timesteps
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_euler_discrete_flax.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from typing import Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import flax
|
| 19 |
+
import jax.numpy as jnp
|
| 20 |
+
|
| 21 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
| 22 |
+
from .scheduling_utils_flax import (
|
| 23 |
+
CommonSchedulerState,
|
| 24 |
+
FlaxKarrasDiffusionSchedulers,
|
| 25 |
+
FlaxSchedulerMixin,
|
| 26 |
+
FlaxSchedulerOutput,
|
| 27 |
+
broadcast_to_shape_from_left,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@flax.struct.dataclass
class EulerDiscreteSchedulerState:
    """Immutable (pytree) state container for `FlaxEulerDiscreteScheduler`."""

    common: CommonSchedulerState

    # setable values
    init_noise_sigma: jnp.ndarray  # std-dev of the initial noise distribution
    timesteps: jnp.ndarray  # discrete timesteps, in decreasing order
    sigmas: jnp.ndarray  # noise levels, one per timestep plus a trailing 0.0
    num_inference_steps: Optional[int] = None  # set by `set_timesteps`

    @classmethod
    def create(
        cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
    ):
        # Convenience constructor; `num_inference_steps` stays None until `set_timesteps` runs.
        return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass
class FlaxEulerDiscreteSchedulerOutput(FlaxSchedulerOutput):
    """Output of `FlaxEulerDiscreteScheduler.step` that also carries the updated scheduler state."""

    state: EulerDiscreteSchedulerState
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
    """
    Euler scheduler (Algorithm 2) from Karras et al. (2022) https://huggingface.co/papers/2206.00364. Based on the
    original k-diffusion implementation by Katherine Crowson:
    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51


    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`jnp.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
            https://imagen.research.google/video/paper.pdf)
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            the `dtype` used for params and computation.
    """

    _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]

    dtype: jnp.dtype

    @property
    def has_state(self):
        # All mutable values live in an external `EulerDiscreteSchedulerState` pytree.
        return True

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[jnp.ndarray] = None,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        dtype: jnp.dtype = jnp.float32,
    ):
        # All other arguments are stored on `self.config` by `register_to_config`.
        self.dtype = dtype

    def create_state(self, common: Optional[CommonSchedulerState] = None) -> EulerDiscreteSchedulerState:
        """Create the initial training-resolution state (one sigma per training timestep)."""
        if common is None:
            common = CommonSchedulerState.create(self)

        timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
        # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), resampled at the (reversed) timesteps
        sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5
        sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas)
        # trailing 0.0 so the final step lands exactly on the data manifold
        sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)])

        # standard deviation of the initial noise distribution
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            init_noise_sigma = sigmas.max()
        else:
            init_noise_sigma = (sigmas.max() ** 2 + 1) ** 0.5

        return EulerDiscreteSchedulerState.create(
            common=common,
            init_noise_sigma=init_noise_sigma,
            timesteps=timesteps,
            sigmas=sigmas,
        )

    def scale_model_input(self, state: EulerDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray:
        """
        Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

        Args:
            state (`EulerDiscreteSchedulerState`):
                the `FlaxEulerDiscreteScheduler` state data class instance.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            timestep (`int`):
                current discrete timestep in the diffusion chain.

        Returns:
            `jnp.ndarray`: scaled input sample
        """
        # `size=1` makes jnp.where shape-static so the lookup is jit-compatible.
        (step_index,) = jnp.where(state.timesteps == timestep, size=1)
        step_index = step_index[0]

        sigma = state.sigmas[step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

    def set_timesteps(
        self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
    ) -> EulerDiscreteSchedulerState:
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`EulerDiscreteSchedulerState`):
                the `FlaxEulerDiscreteScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.

        Returns:
            `EulerDiscreteSchedulerState`: a new state with the inference schedule filled in.
        """

        if self.config.timestep_spacing == "linspace":
            timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // num_inference_steps
            timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
            timesteps += 1
        else:
            raise ValueError(
                f"timestep_spacing must be one of ['linspace', 'leading'], got {self.config.timestep_spacing}"
            )

        # Resample the training sigmas at the chosen inference timesteps.
        sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5
        sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas)
        sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)])

        # standard deviation of the initial noise distribution
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            init_noise_sigma = sigmas.max()
        else:
            init_noise_sigma = (sigmas.max() ** 2 + 1) ** 0.5

        return state.replace(
            timesteps=timesteps,
            sigmas=sigmas,
            num_inference_steps=num_inference_steps,
            init_noise_sigma=init_noise_sigma,
        )

    def step(
        self,
        state: EulerDiscreteSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        return_dict: bool = True,
    ) -> Union[FlaxEulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            state (`EulerDiscreteSchedulerState`):
                the `FlaxEulerDiscreteScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than FlaxEulerDiscreteScheduler class

        Returns:
            [`FlaxEulerDiscreteSchedulerOutput`] or `tuple`: [`FlaxEulerDiscreteSchedulerOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

        """
        if state.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # Locate the timestep in the schedule (shape-static lookup for jit).
        (step_index,) = jnp.where(state.timesteps == timestep, size=1)
        step_index = step_index[0]

        sigma = state.sigmas[step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma

        # dt = sigma_down - sigma
        dt = state.sigmas[step_index + 1] - sigma

        # Explicit Euler step along the probability-flow ODE.
        prev_sample = sample + derivative * dt

        if not return_dict:
            return (prev_sample, state)

        return FlaxEulerDiscreteSchedulerOutput(prev_sample=prev_sample, state=state)

    def add_noise(
        self,
        state: EulerDiscreteSchedulerState,
        original_samples: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        """Noise `original_samples` to the levels indexed by `timesteps`: x_t = x_0 + sigma_t * eps.

        NOTE(review): `timesteps` is used directly as an index into `state.sigmas`,
        so callers are expected to pass schedule positions, not raw timestep values.
        """
        sigma = state.sigmas[timesteps].flatten()
        # Prepend singleton dims so sigma broadcasts against the noise/sample shape.
        sigma = broadcast_to_shape_from_left(sigma, noise.shape)

        noisy_samples = original_samples + noise * sigma

        return noisy_samples

    def __len__(self):
        # The scheduler's length is the configured number of training timesteps.
        return self.config.num_train_timesteps
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
ADDED
|
@@ -0,0 +1,561 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
| 23 |
+
from ..utils import BaseOutput, is_scipy_available, logging
|
| 24 |
+
from .scheduling_utils import SchedulerMixin
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if is_scipy_available():
|
| 28 |
+
import scipy.stats
|
| 29 |
+
|
| 30 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    # The denoised sample to feed into the next model call of the denoising loop.
    prev_sample: torch.FloatTensor
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
| 48 |
+
"""
|
| 49 |
+
Euler scheduler.
|
| 50 |
+
|
| 51 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
| 52 |
+
methods the library implements for all schedulers such as loading and saving.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
num_train_timesteps (`int`, defaults to 1000):
|
| 56 |
+
The number of diffusion steps to train the model.
|
| 57 |
+
shift (`float`, defaults to 1.0):
|
| 58 |
+
The shift value for the timestep schedule.
|
| 59 |
+
use_dynamic_shifting (`bool`, defaults to False):
|
| 60 |
+
Whether to apply timestep shifting on-the-fly based on the image resolution.
|
| 61 |
+
base_shift (`float`, defaults to 0.5):
|
| 62 |
+
Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent
|
| 63 |
+
with desired output.
|
| 64 |
+
max_shift (`float`, defaults to 1.15):
|
| 65 |
+
Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be
|
| 66 |
+
more exaggerated or stylized.
|
| 67 |
+
base_image_seq_len (`int`, defaults to 256):
|
| 68 |
+
The base image sequence length.
|
| 69 |
+
max_image_seq_len (`int`, defaults to 4096):
|
| 70 |
+
The maximum image sequence length.
|
| 71 |
+
invert_sigmas (`bool`, defaults to False):
|
| 72 |
+
Whether to invert the sigmas.
|
| 73 |
+
shift_terminal (`float`, defaults to None):
|
| 74 |
+
The end value of the shifted timestep schedule.
|
| 75 |
+
use_karras_sigmas (`bool`, defaults to False):
|
| 76 |
+
Whether to use Karras sigmas for step sizes in the noise schedule during sampling.
|
| 77 |
+
use_exponential_sigmas (`bool`, defaults to False):
|
| 78 |
+
Whether to use exponential sigmas for step sizes in the noise schedule during sampling.
|
| 79 |
+
use_beta_sigmas (`bool`, defaults to False):
|
| 80 |
+
Whether to use beta sigmas for step sizes in the noise schedule during sampling.
|
| 81 |
+
time_shift_type (`str`, defaults to "exponential"):
|
| 82 |
+
The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
|
| 83 |
+
stochastic_sampling (`bool`, defaults to False):
|
| 84 |
+
Whether to use stochastic sampling.
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
_compatibles = []
|
| 88 |
+
order = 1
|
| 89 |
+
|
| 90 |
+
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    shift: float = 1.0,
    use_dynamic_shifting: bool = False,
    base_shift: Optional[float] = 0.5,
    max_shift: Optional[float] = 1.15,
    base_image_seq_len: Optional[int] = 256,
    max_image_seq_len: Optional[int] = 4096,
    invert_sigmas: bool = False,
    shift_terminal: Optional[float] = None,
    use_karras_sigmas: Optional[bool] = False,
    use_exponential_sigmas: Optional[bool] = False,
    use_beta_sigmas: Optional[bool] = False,
    time_shift_type: str = "exponential",
    stochastic_sampling: bool = False,
):
    """Build the flow-matching training schedule; see the class docstring for argument semantics."""
    # Beta sigmas need scipy.stats; fail early with a clear message.
    if self.config.use_beta_sigmas and not is_scipy_available():
        raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
    # The three sigma families are mutually exclusive.
    if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
        raise ValueError(
            "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
        )
    if time_shift_type not in {"exponential", "linear"}:
        raise ValueError("`time_shift_type` must either be 'exponential' or 'linear'.")

    # Training timesteps run from num_train_timesteps down to 1.
    timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
    timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

    # sigma is the normalized timestep in (0, 1].
    sigmas = timesteps / num_train_timesteps
    if not use_dynamic_shifting:
        # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

    self.timesteps = sigmas * num_train_timesteps

    # Step bookkeeping; populated by set_begin_index / step.
    self._step_index = None
    self._begin_index = None

    self._shift = shift

    self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
    self.sigma_min = self.sigmas[-1].item()
    self.sigma_max = self.sigmas[0].item()
| 135 |
+
|
| 136 |
+
@property
def shift(self):
    """
    The value used for shifting.
    """
    return self._shift
| 142 |
+
|
| 143 |
+
@property
def step_index(self):
    """
    The index counter for current timestep. It will increase 1 after each scheduler step.
    """
    return self._step_index
| 149 |
+
|
| 150 |
+
@property
def begin_index(self):
    """
    The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
    """
    return self._begin_index
| 156 |
+
|
| 157 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
    """
    Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

    Args:
        begin_index (`int`):
            The begin index for the scheduler.
    """
    self._begin_index = begin_index
| 167 |
+
|
| 168 |
+
def set_shift(self, shift: float):
    """Update the shift value applied to the timestep schedule."""
    self._shift = shift
| 170 |
+
|
| 171 |
+
def scale_noise(
    self,
    sample: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
    """
    Forward process in flow-matching: linearly interpolates between `noise` and `sample`
    with weight `sigma_t` (x_t = sigma * noise + (1 - sigma) * x_0).

    Args:
        sample (`torch.FloatTensor`):
            The input sample.
        timestep (`float` or `torch.FloatTensor`):
            The current timestep(s) in the diffusion chain; iterated over, so a tensor is expected here.
        noise (`torch.FloatTensor`, *optional*):
            The noise tensor to mix in. NOTE(review): despite being optional, it is used
            unconditionally below — passing None raises; confirm against callers.

    Returns:
        `torch.FloatTensor`:
            A scaled input sample.
    """
    # Make sure sigmas and timesteps have the same device and dtype as original_samples
    sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)

    if sample.device.type == "mps" and torch.is_floating_point(timestep):
        # mps does not support float64
        schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
        timestep = timestep.to(sample.device, dtype=torch.float32)
    else:
        schedule_timesteps = self.timesteps.to(sample.device)
        timestep = timestep.to(sample.device)

    # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
    if self.begin_index is None:
        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
    elif self.step_index is not None:
        # add_noise is called after first denoising step (for inpainting)
        step_indices = [self.step_index] * timestep.shape[0]
    else:
        # add noise is called before first denoising step to create initial latent(img2img)
        step_indices = [self.begin_index] * timestep.shape[0]

    sigma = sigmas[step_indices].flatten()
    # Append trailing singleton dims so sigma broadcasts against the sample shape.
    while len(sigma.shape) < len(sample.shape):
        sigma = sigma.unsqueeze(-1)

    sample = sigma * noise + (1.0 - sigma) * sample

    return sample
| 218 |
+
|
| 219 |
+
def _sigma_to_t(self, sigma):
    # sigma in [0, 1] maps linearly onto the training timestep range.
    return sigma * self.config.num_train_timesteps
| 221 |
+
|
| 222 |
+
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
    # Dispatch to the configured dynamic-shift rule. `__init__` validates
    # `time_shift_type` to be "exponential" or "linear", so no other value reaches here.
    if self.config.time_shift_type == "exponential":
        return self._time_shift_exponential(mu, sigma, t)
    elif self.config.time_shift_type == "linear":
        return self._time_shift_linear(mu, sigma, t)
| 227 |
+
|
| 228 |
+
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
|
| 229 |
+
r"""
|
| 230 |
+
Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
|
| 231 |
+
value.
|
| 232 |
+
|
| 233 |
+
Reference:
|
| 234 |
+
https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
|
| 235 |
+
|
| 236 |
+
Args:
|
| 237 |
+
t (`torch.Tensor`):
|
| 238 |
+
A tensor of timesteps to be stretched and shifted.
|
| 239 |
+
|
| 240 |
+
Returns:
|
| 241 |
+
`torch.Tensor`:
|
| 242 |
+
A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
|
| 243 |
+
"""
|
| 244 |
+
one_minus_z = 1 - t
|
| 245 |
+
scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
|
| 246 |
+
stretched_t = 1 - (one_minus_z / scale_factor)
|
| 247 |
+
return stretched_t
|
| 248 |
+
|
| 249 |
+
def set_timesteps(
|
| 250 |
+
self,
|
| 251 |
+
num_inference_steps: Optional[int] = None,
|
| 252 |
+
device: Union[str, torch.device] = None,
|
| 253 |
+
sigmas: Optional[List[float]] = None,
|
| 254 |
+
mu: Optional[float] = None,
|
| 255 |
+
timesteps: Optional[List[float]] = None,
|
| 256 |
+
):
|
| 257 |
+
"""
|
| 258 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
| 259 |
+
|
| 260 |
+
Args:
|
| 261 |
+
num_inference_steps (`int`, *optional*):
|
| 262 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
| 263 |
+
device (`str` or `torch.device`, *optional*):
|
| 264 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 265 |
+
sigmas (`List[float]`, *optional*):
|
| 266 |
+
Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
|
| 267 |
+
automatically.
|
| 268 |
+
mu (`float`, *optional*):
|
| 269 |
+
Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep
|
| 270 |
+
shifting.
|
| 271 |
+
timesteps (`List[float]`, *optional*):
|
| 272 |
+
Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed
|
| 273 |
+
automatically.
|
| 274 |
+
"""
|
| 275 |
+
if self.config.use_dynamic_shifting and mu is None:
|
| 276 |
+
raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")
|
| 277 |
+
|
| 278 |
+
if sigmas is not None and timesteps is not None:
|
| 279 |
+
if len(sigmas) != len(timesteps):
|
| 280 |
+
raise ValueError("`sigmas` and `timesteps` should have the same length")
|
| 281 |
+
|
| 282 |
+
if num_inference_steps is not None:
|
| 283 |
+
if (sigmas is not None and len(sigmas) != num_inference_steps) or (
|
| 284 |
+
timesteps is not None and len(timesteps) != num_inference_steps
|
| 285 |
+
):
|
| 286 |
+
raise ValueError(
|
| 287 |
+
"`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided"
|
| 288 |
+
)
|
| 289 |
+
else:
|
| 290 |
+
num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps)
|
| 291 |
+
|
| 292 |
+
self.num_inference_steps = num_inference_steps
|
| 293 |
+
|
| 294 |
+
# 1. Prepare default sigmas
|
| 295 |
+
is_timesteps_provided = timesteps is not None
|
| 296 |
+
|
| 297 |
+
if is_timesteps_provided:
|
| 298 |
+
timesteps = np.array(timesteps).astype(np.float32)
|
| 299 |
+
|
| 300 |
+
if sigmas is None:
|
| 301 |
+
if timesteps is None:
|
| 302 |
+
timesteps = np.linspace(
|
| 303 |
+
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
|
| 304 |
+
)
|
| 305 |
+
sigmas = timesteps / self.config.num_train_timesteps
|
| 306 |
+
else:
|
| 307 |
+
sigmas = np.array(sigmas).astype(np.float32)
|
| 308 |
+
num_inference_steps = len(sigmas)
|
| 309 |
+
|
| 310 |
+
# 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of
|
| 311 |
+
# "exponential" or "linear" type is applied
|
| 312 |
+
if self.config.use_dynamic_shifting:
|
| 313 |
+
sigmas = self.time_shift(mu, 1.0, sigmas)
|
| 314 |
+
else:
|
| 315 |
+
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
|
| 316 |
+
|
| 317 |
+
# 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value
|
| 318 |
+
if self.config.shift_terminal:
|
| 319 |
+
sigmas = self.stretch_shift_to_terminal(sigmas)
|
| 320 |
+
|
| 321 |
+
# 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules
|
| 322 |
+
if self.config.use_karras_sigmas:
|
| 323 |
+
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
| 324 |
+
elif self.config.use_exponential_sigmas:
|
| 325 |
+
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
| 326 |
+
elif self.config.use_beta_sigmas:
|
| 327 |
+
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
| 328 |
+
|
| 329 |
+
# 5. Convert sigmas and timesteps to tensors and move to specified device
|
| 330 |
+
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
| 331 |
+
if not is_timesteps_provided:
|
| 332 |
+
timesteps = sigmas * self.config.num_train_timesteps
|
| 333 |
+
else:
|
| 334 |
+
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
|
| 335 |
+
|
| 336 |
+
# 6. Append the terminal sigma value.
|
| 337 |
+
# If a model requires inverted sigma schedule for denoising but timesteps without inversion, the
|
| 338 |
+
# `invert_sigmas` flag can be set to `True`. This case is only required in Mochi
|
| 339 |
+
if self.config.invert_sigmas:
|
| 340 |
+
sigmas = 1.0 - sigmas
|
| 341 |
+
timesteps = sigmas * self.config.num_train_timesteps
|
| 342 |
+
sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
|
| 343 |
+
else:
|
| 344 |
+
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
| 345 |
+
|
| 346 |
+
self.timesteps = timesteps
|
| 347 |
+
self.sigmas = sigmas
|
| 348 |
+
self._step_index = None
|
| 349 |
+
self._begin_index = None
|
| 350 |
+
|
| 351 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
| 352 |
+
if schedule_timesteps is None:
|
| 353 |
+
schedule_timesteps = self.timesteps
|
| 354 |
+
|
| 355 |
+
indices = (schedule_timesteps == timestep).nonzero()
|
| 356 |
+
|
| 357 |
+
# The sigma index that is taken for the **very** first `step`
|
| 358 |
+
# is always the second index (or the last index if there is only 1)
|
| 359 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
| 360 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
| 361 |
+
pos = 1 if len(indices) > 1 else 0
|
| 362 |
+
|
| 363 |
+
return indices[pos].item()
|
| 364 |
+
|
| 365 |
+
def _init_step_index(self, timestep):
|
| 366 |
+
if self.begin_index is None:
|
| 367 |
+
if isinstance(timestep, torch.Tensor):
|
| 368 |
+
timestep = timestep.to(self.timesteps.device)
|
| 369 |
+
self._step_index = self.index_for_timestep(timestep)
|
| 370 |
+
else:
|
| 371 |
+
self._step_index = self._begin_index
|
| 372 |
+
|
| 373 |
+
def step(
|
| 374 |
+
self,
|
| 375 |
+
model_output: torch.FloatTensor,
|
| 376 |
+
timestep: Union[float, torch.FloatTensor],
|
| 377 |
+
sample: torch.FloatTensor,
|
| 378 |
+
s_churn: float = 0.0,
|
| 379 |
+
s_tmin: float = 0.0,
|
| 380 |
+
s_tmax: float = float("inf"),
|
| 381 |
+
s_noise: float = 1.0,
|
| 382 |
+
generator: Optional[torch.Generator] = None,
|
| 383 |
+
per_token_timesteps: Optional[torch.Tensor] = None,
|
| 384 |
+
return_dict: bool = True,
|
| 385 |
+
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
|
| 386 |
+
"""
|
| 387 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
| 388 |
+
process from the learned model outputs (most often the predicted noise).
|
| 389 |
+
|
| 390 |
+
Args:
|
| 391 |
+
model_output (`torch.FloatTensor`):
|
| 392 |
+
The direct output from learned diffusion model.
|
| 393 |
+
timestep (`float`):
|
| 394 |
+
The current discrete timestep in the diffusion chain.
|
| 395 |
+
sample (`torch.FloatTensor`):
|
| 396 |
+
A current instance of a sample created by the diffusion process.
|
| 397 |
+
s_churn (`float`):
|
| 398 |
+
s_tmin (`float`):
|
| 399 |
+
s_tmax (`float`):
|
| 400 |
+
s_noise (`float`, defaults to 1.0):
|
| 401 |
+
Scaling factor for noise added to the sample.
|
| 402 |
+
generator (`torch.Generator`, *optional*):
|
| 403 |
+
A random number generator.
|
| 404 |
+
per_token_timesteps (`torch.Tensor`, *optional*):
|
| 405 |
+
The timesteps for each token in the sample.
|
| 406 |
+
return_dict (`bool`):
|
| 407 |
+
Whether or not to return a
|
| 408 |
+
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
|
| 409 |
+
|
| 410 |
+
Returns:
|
| 411 |
+
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
|
| 412 |
+
If return_dict is `True`,
|
| 413 |
+
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned,
|
| 414 |
+
otherwise a tuple is returned where the first element is the sample tensor.
|
| 415 |
+
"""
|
| 416 |
+
|
| 417 |
+
if (
|
| 418 |
+
isinstance(timestep, int)
|
| 419 |
+
or isinstance(timestep, torch.IntTensor)
|
| 420 |
+
or isinstance(timestep, torch.LongTensor)
|
| 421 |
+
):
|
| 422 |
+
raise ValueError(
|
| 423 |
+
(
|
| 424 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
| 425 |
+
" `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
| 426 |
+
" one of the `scheduler.timesteps` as a timestep."
|
| 427 |
+
),
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
if self.step_index is None:
|
| 431 |
+
self._init_step_index(timestep)
|
| 432 |
+
|
| 433 |
+
# Upcast to avoid precision issues when computing prev_sample
|
| 434 |
+
sample = sample.to(torch.float32)
|
| 435 |
+
|
| 436 |
+
if per_token_timesteps is not None:
|
| 437 |
+
per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps
|
| 438 |
+
|
| 439 |
+
sigmas = self.sigmas[:, None, None]
|
| 440 |
+
lower_mask = sigmas < per_token_sigmas[None] - 1e-6
|
| 441 |
+
lower_sigmas = lower_mask * sigmas
|
| 442 |
+
lower_sigmas, _ = lower_sigmas.max(dim=0)
|
| 443 |
+
|
| 444 |
+
current_sigma = per_token_sigmas[..., None]
|
| 445 |
+
next_sigma = lower_sigmas[..., None]
|
| 446 |
+
dt = current_sigma - next_sigma
|
| 447 |
+
else:
|
| 448 |
+
sigma_idx = self.step_index
|
| 449 |
+
sigma = self.sigmas[sigma_idx]
|
| 450 |
+
sigma_next = self.sigmas[sigma_idx + 1]
|
| 451 |
+
|
| 452 |
+
current_sigma = sigma
|
| 453 |
+
next_sigma = sigma_next
|
| 454 |
+
dt = sigma_next - sigma
|
| 455 |
+
|
| 456 |
+
if self.config.stochastic_sampling:
|
| 457 |
+
x0 = sample - current_sigma * model_output
|
| 458 |
+
noise = torch.randn_like(sample)
|
| 459 |
+
prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
|
| 460 |
+
else:
|
| 461 |
+
prev_sample = sample + dt * model_output
|
| 462 |
+
|
| 463 |
+
# upon completion increase step index by one
|
| 464 |
+
self._step_index += 1
|
| 465 |
+
if per_token_timesteps is None:
|
| 466 |
+
# Cast sample back to model compatible dtype
|
| 467 |
+
prev_sample = prev_sample.to(model_output.dtype)
|
| 468 |
+
|
| 469 |
+
if not return_dict:
|
| 470 |
+
return (prev_sample,)
|
| 471 |
+
|
| 472 |
+
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
|
| 473 |
+
|
| 474 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
| 475 |
+
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
| 476 |
+
"""Constructs the noise schedule of Karras et al. (2022)."""
|
| 477 |
+
|
| 478 |
+
# Hack to make sure that other schedulers which copy this function don't break
|
| 479 |
+
# TODO: Add this logic to the other schedulers
|
| 480 |
+
if hasattr(self.config, "sigma_min"):
|
| 481 |
+
sigma_min = self.config.sigma_min
|
| 482 |
+
else:
|
| 483 |
+
sigma_min = None
|
| 484 |
+
|
| 485 |
+
if hasattr(self.config, "sigma_max"):
|
| 486 |
+
sigma_max = self.config.sigma_max
|
| 487 |
+
else:
|
| 488 |
+
sigma_max = None
|
| 489 |
+
|
| 490 |
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
| 491 |
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
| 492 |
+
|
| 493 |
+
rho = 7.0 # 7.0 is the value used in the paper
|
| 494 |
+
ramp = np.linspace(0, 1, num_inference_steps)
|
| 495 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
| 496 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
| 497 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
| 498 |
+
return sigmas
|
| 499 |
+
|
| 500 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
|
| 501 |
+
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
|
| 502 |
+
"""Constructs an exponential noise schedule."""
|
| 503 |
+
|
| 504 |
+
# Hack to make sure that other schedulers which copy this function don't break
|
| 505 |
+
# TODO: Add this logic to the other schedulers
|
| 506 |
+
if hasattr(self.config, "sigma_min"):
|
| 507 |
+
sigma_min = self.config.sigma_min
|
| 508 |
+
else:
|
| 509 |
+
sigma_min = None
|
| 510 |
+
|
| 511 |
+
if hasattr(self.config, "sigma_max"):
|
| 512 |
+
sigma_max = self.config.sigma_max
|
| 513 |
+
else:
|
| 514 |
+
sigma_max = None
|
| 515 |
+
|
| 516 |
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
| 517 |
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
| 518 |
+
|
| 519 |
+
sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
|
| 520 |
+
return sigmas
|
| 521 |
+
|
| 522 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
|
| 523 |
+
def _convert_to_beta(
|
| 524 |
+
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
| 525 |
+
) -> torch.Tensor:
|
| 526 |
+
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
| 527 |
+
|
| 528 |
+
# Hack to make sure that other schedulers which copy this function don't break
|
| 529 |
+
# TODO: Add this logic to the other schedulers
|
| 530 |
+
if hasattr(self.config, "sigma_min"):
|
| 531 |
+
sigma_min = self.config.sigma_min
|
| 532 |
+
else:
|
| 533 |
+
sigma_min = None
|
| 534 |
+
|
| 535 |
+
if hasattr(self.config, "sigma_max"):
|
| 536 |
+
sigma_max = self.config.sigma_max
|
| 537 |
+
else:
|
| 538 |
+
sigma_max = None
|
| 539 |
+
|
| 540 |
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
| 541 |
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
| 542 |
+
|
| 543 |
+
sigmas = np.array(
|
| 544 |
+
[
|
| 545 |
+
sigma_min + (ppf * (sigma_max - sigma_min))
|
| 546 |
+
for ppf in [
|
| 547 |
+
scipy.stats.beta.ppf(timestep, alpha, beta)
|
| 548 |
+
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
| 549 |
+
]
|
| 550 |
+
]
|
| 551 |
+
)
|
| 552 |
+
return sigmas
|
| 553 |
+
|
| 554 |
+
def _time_shift_exponential(self, mu, sigma, t):
|
| 555 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
| 556 |
+
|
| 557 |
+
def _time_shift_linear(self, mu, sigma, t):
|
| 558 |
+
return mu / (mu + (1 / t - 1) ** sigma)
|
| 559 |
+
|
| 560 |
+
def __len__(self):
|
| 561 |
+
return self.config.num_train_timesteps
|
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_flow_match_heun_discrete.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from typing import Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
| 22 |
+
from ..utils import BaseOutput, logging
|
| 23 |
+
from ..utils.torch_utils import randn_tensor
|
| 24 |
+
from .scheduling_utils import SchedulerMixin
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput):
|
| 32 |
+
"""
|
| 33 |
+
Output class for the scheduler's `step` function output.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
| 37 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
| 38 |
+
denoising loop.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
prev_sample: torch.FloatTensor
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
| 45 |
+
"""
|
| 46 |
+
Heun scheduler.
|
| 47 |
+
|
| 48 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
| 49 |
+
methods the library implements for all schedulers such as loading and saving.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
num_train_timesteps (`int`, defaults to 1000):
|
| 53 |
+
The number of diffusion steps to train the model.
|
| 54 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
| 55 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
| 56 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
| 57 |
+
shift (`float`, defaults to 1.0):
|
| 58 |
+
The shift value for the timestep schedule.
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
_compatibles = []
|
| 62 |
+
order = 2
|
| 63 |
+
|
| 64 |
+
@register_to_config
|
| 65 |
+
def __init__(
|
| 66 |
+
self,
|
| 67 |
+
num_train_timesteps: int = 1000,
|
| 68 |
+
shift: float = 1.0,
|
| 69 |
+
):
|
| 70 |
+
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
| 71 |
+
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
|
| 72 |
+
|
| 73 |
+
sigmas = timesteps / num_train_timesteps
|
| 74 |
+
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
| 75 |
+
|
| 76 |
+
self.timesteps = sigmas * num_train_timesteps
|
| 77 |
+
|
| 78 |
+
self._step_index = None
|
| 79 |
+
self._begin_index = None
|
| 80 |
+
|
| 81 |
+
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
| 82 |
+
self.sigma_min = self.sigmas[-1].item()
|
| 83 |
+
self.sigma_max = self.sigmas[0].item()
|
| 84 |
+
|
| 85 |
+
@property
|
| 86 |
+
def step_index(self):
|
| 87 |
+
"""
|
| 88 |
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
| 89 |
+
"""
|
| 90 |
+
return self._step_index
|
| 91 |
+
|
| 92 |
+
@property
|
| 93 |
+
def begin_index(self):
|
| 94 |
+
"""
|
| 95 |
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
| 96 |
+
"""
|
| 97 |
+
return self._begin_index
|
| 98 |
+
|
| 99 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
| 100 |
+
def set_begin_index(self, begin_index: int = 0):
|
| 101 |
+
"""
|
| 102 |
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
begin_index (`int`):
|
| 106 |
+
The begin index for the scheduler.
|
| 107 |
+
"""
|
| 108 |
+
self._begin_index = begin_index
|
| 109 |
+
|
| 110 |
+
def scale_noise(
|
| 111 |
+
self,
|
| 112 |
+
sample: torch.FloatTensor,
|
| 113 |
+
timestep: Union[float, torch.FloatTensor],
|
| 114 |
+
noise: Optional[torch.FloatTensor] = None,
|
| 115 |
+
) -> torch.FloatTensor:
|
| 116 |
+
"""
|
| 117 |
+
Forward process in flow-matching
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
sample (`torch.FloatTensor`):
|
| 121 |
+
The input sample.
|
| 122 |
+
timestep (`int`, *optional*):
|
| 123 |
+
The current timestep in the diffusion chain.
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
`torch.FloatTensor`:
|
| 127 |
+
A scaled input sample.
|
| 128 |
+
"""
|
| 129 |
+
if self.step_index is None:
|
| 130 |
+
self._init_step_index(timestep)
|
| 131 |
+
|
| 132 |
+
sigma = self.sigmas[self.step_index]
|
| 133 |
+
sample = sigma * noise + (1.0 - sigma) * sample
|
| 134 |
+
|
| 135 |
+
return sample
|
| 136 |
+
|
| 137 |
+
def _sigma_to_t(self, sigma):
|
| 138 |
+
return sigma * self.config.num_train_timesteps
|
| 139 |
+
|
| 140 |
+
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
| 141 |
+
"""
|
| 142 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
num_inference_steps (`int`):
|
| 146 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
| 147 |
+
device (`str` or `torch.device`, *optional*):
|
| 148 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 149 |
+
"""
|
| 150 |
+
self.num_inference_steps = num_inference_steps
|
| 151 |
+
|
| 152 |
+
timesteps = np.linspace(
|
| 153 |
+
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
sigmas = timesteps / self.config.num_train_timesteps
|
| 157 |
+
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
|
| 158 |
+
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
| 159 |
+
|
| 160 |
+
timesteps = sigmas * self.config.num_train_timesteps
|
| 161 |
+
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
|
| 162 |
+
self.timesteps = timesteps.to(device=device)
|
| 163 |
+
|
| 164 |
+
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
| 165 |
+
self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
|
| 166 |
+
|
| 167 |
+
# empty dt and derivative
|
| 168 |
+
self.prev_derivative = None
|
| 169 |
+
self.dt = None
|
| 170 |
+
|
| 171 |
+
self._step_index = None
|
| 172 |
+
self._begin_index = None
|
| 173 |
+
|
| 174 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
| 175 |
+
if schedule_timesteps is None:
|
| 176 |
+
schedule_timesteps = self.timesteps
|
| 177 |
+
|
| 178 |
+
indices = (schedule_timesteps == timestep).nonzero()
|
| 179 |
+
|
| 180 |
+
# The sigma index that is taken for the **very** first `step`
|
| 181 |
+
# is always the second index (or the last index if there is only 1)
|
| 182 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
| 183 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
| 184 |
+
pos = 1 if len(indices) > 1 else 0
|
| 185 |
+
|
| 186 |
+
return indices[pos].item()
|
| 187 |
+
|
| 188 |
+
def _init_step_index(self, timestep):
|
| 189 |
+
if self.begin_index is None:
|
| 190 |
+
if isinstance(timestep, torch.Tensor):
|
| 191 |
+
timestep = timestep.to(self.timesteps.device)
|
| 192 |
+
self._step_index = self.index_for_timestep(timestep)
|
| 193 |
+
else:
|
| 194 |
+
self._step_index = self._begin_index
|
| 195 |
+
|
| 196 |
+
@property
|
| 197 |
+
def state_in_first_order(self):
|
| 198 |
+
return self.dt is None
|
| 199 |
+
|
| 200 |
+
def step(
|
| 201 |
+
self,
|
| 202 |
+
model_output: torch.FloatTensor,
|
| 203 |
+
timestep: Union[float, torch.FloatTensor],
|
| 204 |
+
sample: torch.FloatTensor,
|
| 205 |
+
s_churn: float = 0.0,
|
| 206 |
+
s_tmin: float = 0.0,
|
| 207 |
+
s_tmax: float = float("inf"),
|
| 208 |
+
s_noise: float = 1.0,
|
| 209 |
+
generator: Optional[torch.Generator] = None,
|
| 210 |
+
return_dict: bool = True,
|
| 211 |
+
) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]:
|
| 212 |
+
"""
|
| 213 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
| 214 |
+
process from the learned model outputs (most often the predicted noise).
|
| 215 |
+
|
| 216 |
+
Args:
|
| 217 |
+
model_output (`torch.FloatTensor`):
|
| 218 |
+
The direct output from learned diffusion model.
|
| 219 |
+
timestep (`float`):
|
| 220 |
+
The current discrete timestep in the diffusion chain.
|
| 221 |
+
sample (`torch.FloatTensor`):
|
| 222 |
+
A current instance of a sample created by the diffusion process.
|
| 223 |
+
s_churn (`float`):
|
| 224 |
+
s_tmin (`float`):
|
| 225 |
+
s_tmax (`float`):
|
| 226 |
+
s_noise (`float`, defaults to 1.0):
|
| 227 |
+
Scaling factor for noise added to the sample.
|
| 228 |
+
generator (`torch.Generator`, *optional*):
|
| 229 |
+
A random number generator.
|
| 230 |
+
return_dict (`bool`):
|
| 231 |
+
Whether or not to return a
|
| 232 |
+
[`~schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput`] tuple.
|
| 233 |
+
|
| 234 |
+
Returns:
|
| 235 |
+
[`~schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput`] or `tuple`:
|
| 236 |
+
If return_dict is `True`,
|
| 237 |
+
[`~schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput`] is returned,
|
| 238 |
+
otherwise a tuple is returned where the first element is the sample tensor.
|
| 239 |
+
"""
|
| 240 |
+
|
| 241 |
+
if (
|
| 242 |
+
isinstance(timestep, int)
|
| 243 |
+
or isinstance(timestep, torch.IntTensor)
|
| 244 |
+
or isinstance(timestep, torch.LongTensor)
|
| 245 |
+
):
|
| 246 |
+
raise ValueError(
|
| 247 |
+
(
|
| 248 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
| 249 |
+
" `FlowMatchHeunDiscreteScheduler.step()` is not supported. Make sure to pass"
|
| 250 |
+
" one of the `scheduler.timesteps` as a timestep."
|
| 251 |
+
),
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
if self.step_index is None:
|
| 255 |
+
self._init_step_index(timestep)
|
| 256 |
+
|
| 257 |
+
# Upcast to avoid precision issues when computing prev_sample
|
| 258 |
+
sample = sample.to(torch.float32)
|
| 259 |
+
|
| 260 |
+
if self.state_in_first_order:
|
| 261 |
+
sigma = self.sigmas[self.step_index]
|
| 262 |
+
sigma_next = self.sigmas[self.step_index + 1]
|
| 263 |
+
else:
|
| 264 |
+
# 2nd order / Heun's method
|
| 265 |
+
sigma = self.sigmas[self.step_index - 1]
|
| 266 |
+
sigma_next = self.sigmas[self.step_index]
|
| 267 |
+
|
| 268 |
+
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
|
| 269 |
+
|
| 270 |
+
sigma_hat = sigma * (gamma + 1)
|
| 271 |
+
|
| 272 |
+
if gamma > 0:
|
| 273 |
+
noise = randn_tensor(
|
| 274 |
+
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
|
| 275 |
+
)
|
| 276 |
+
eps = noise * s_noise
|
| 277 |
+
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
| 278 |
+
|
| 279 |
+
if self.state_in_first_order:
|
| 280 |
+
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
| 281 |
+
denoised = sample - model_output * sigma
|
| 282 |
+
# 2. convert to an ODE derivative for 1st order
|
| 283 |
+
derivative = (sample - denoised) / sigma_hat
|
| 284 |
+
# 3. Delta timestep
|
| 285 |
+
dt = sigma_next - sigma_hat
|
| 286 |
+
|
| 287 |
+
# store for 2nd order step
|
| 288 |
+
self.prev_derivative = derivative
|
| 289 |
+
self.dt = dt
|
| 290 |
+
self.sample = sample
|
| 291 |
+
else:
|
| 292 |
+
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
| 293 |
+
denoised = sample - model_output * sigma_next
|
| 294 |
+
# 2. 2nd order / Heun's method
|
| 295 |
+
derivative = (sample - denoised) / sigma_next
|
| 296 |
+
derivative = 0.5 * (self.prev_derivative + derivative)
|
| 297 |
+
|
| 298 |
+
# 3. take prev timestep & sample
|
| 299 |
+
dt = self.dt
|
| 300 |
+
sample = self.sample
|
| 301 |
+
|
| 302 |
+
# free dt and derivative
|
| 303 |
+
# Note, this puts the scheduler in "first order mode"
|
| 304 |
+
self.prev_derivative = None
|
| 305 |
+
self.dt = None
|
| 306 |
+
self.sample = None
|
| 307 |
+
|
| 308 |
+
prev_sample = sample + derivative * dt
|
| 309 |
+
# Cast sample back to model compatible dtype
|
| 310 |
+
prev_sample = prev_sample.to(model_output.dtype)
|
| 311 |
+
|
| 312 |
+
# upon completion increase step index by one
|
| 313 |
+
self._step_index += 1
|
| 314 |
+
|
| 315 |
+
if not return_dict:
|
| 316 |
+
return (prev_sample,)
|
| 317 |
+
|
| 318 |
+
return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)
|
| 319 |
+
|
| 320 |
+
def __len__(self):
|
| 321 |
+
return self.config.num_train_timesteps
|
pythonProject/.venv/Lib/site-packages/fsspec/__init__.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import caching
|
| 2 |
+
from ._version import __version__ # noqa: F401
|
| 3 |
+
from .callbacks import Callback
|
| 4 |
+
from .compression import available_compressions
|
| 5 |
+
from .core import get_fs_token_paths, open, open_files, open_local, url_to_fs
|
| 6 |
+
from .exceptions import FSTimeoutError
|
| 7 |
+
from .mapping import FSMap, get_mapper
|
| 8 |
+
from .registry import (
|
| 9 |
+
available_protocols,
|
| 10 |
+
filesystem,
|
| 11 |
+
get_filesystem_class,
|
| 12 |
+
register_implementation,
|
| 13 |
+
registry,
|
| 14 |
+
)
|
| 15 |
+
from .spec import AbstractFileSystem
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
"AbstractFileSystem",
|
| 19 |
+
"FSTimeoutError",
|
| 20 |
+
"FSMap",
|
| 21 |
+
"filesystem",
|
| 22 |
+
"register_implementation",
|
| 23 |
+
"get_filesystem_class",
|
| 24 |
+
"get_fs_token_paths",
|
| 25 |
+
"get_mapper",
|
| 26 |
+
"open",
|
| 27 |
+
"open_files",
|
| 28 |
+
"open_local",
|
| 29 |
+
"registry",
|
| 30 |
+
"caching",
|
| 31 |
+
"Callback",
|
| 32 |
+
"available_protocols",
|
| 33 |
+
"available_compressions",
|
| 34 |
+
"url_to_fs",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def process_entries():
    """Discover and register filesystem implementations advertised through
    the ``fsspec.specs`` entry-point group.

    Runs once at import time. Each entry point is registered lazily by its
    dotted class path; if loading later fails, the stored error text is
    raised only when the protocol is actually requested. The first entry
    point seen for a given protocol name wins.
    """
    try:
        from importlib.metadata import entry_points
    except ImportError:
        # importlib.metadata is stdlib from Python 3.8; nothing to do before.
        return
    try:
        eps = entry_points()
    except TypeError:
        return  # importlib-metadata < 0.8
    if hasattr(eps, "select"):  # Python 3.10+ / importlib_metadata >= 3.9.0
        specs = eps.select(group="fsspec.specs")
    else:
        specs = eps.get("fsspec.specs", [])
    registered_names = set()
    for spec in specs:
        name = spec.name
        if name in registered_names:
            # Duplicate protocol advertised by another distribution: skip.
            continue
        registered_names.add(name)
        register_implementation(
            name,
            spec.value.replace(":", "."),
            errtxt=f"Unable to load filesystem from {spec}",
            # We take our implementations as the ones to overload with if
            # for some reason we encounter some, may be the same, already
            # registered
            clobber=True,
        )


process_entries()
|
pythonProject/.venv/Lib/site-packages/fsspec/_version.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# file generated by setuptools-scm
# don't change, don't track in version control

__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# Constant-False TYPE_CHECKING avoids importing ``typing`` at runtime while
# still letting static type checkers see the precise alias definitions below.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
    COMMIT_ID = Union[str, None]
else:
    # At runtime the aliases are placeholders; they are only used in the
    # annotations below, never instantiated.
    VERSION_TUPLE = object
    COMMIT_ID = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID

__version__ = version = '2025.9.0'
__version_tuple__ = version_tuple = (2025, 9, 0)

__commit_id__ = commit_id = None
|
pythonProject/.venv/Lib/site-packages/fsspec/implementations/arrow.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import errno
|
| 2 |
+
import io
|
| 3 |
+
import os
|
| 4 |
+
import secrets
|
| 5 |
+
import shutil
|
| 6 |
+
from contextlib import suppress
|
| 7 |
+
from functools import cached_property, wraps
|
| 8 |
+
from urllib.parse import parse_qs
|
| 9 |
+
|
| 10 |
+
from fsspec.spec import AbstractFileSystem
|
| 11 |
+
from fsspec.utils import (
|
| 12 |
+
get_package_version_without_import,
|
| 13 |
+
infer_storage_options,
|
| 14 |
+
mirror_from,
|
| 15 |
+
tokenize,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def wrap_exceptions(func):
    """Decorator translating pyarrow's "does not exist" ``OSError`` into the
    standard ``FileNotFoundError`` so fsspec callers can catch it uniformly.

    Any other ``OSError`` (or one without arguments) is re-raised unchanged.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except OSError as exc:
            if not exc.args:
                raise
            message = exc.args[0]
            if isinstance(message, str) and "does not exist" in message:
                raise FileNotFoundError(errno.ENOENT, message) from exc
            raise

    return wrapper
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
PYARROW_VERSION = None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ArrowFSWrapper(AbstractFileSystem):
    """FSSpec-compatible wrapper of pyarrow.fs.FileSystem.

    Parameters
    ----------
    fs : pyarrow.fs.FileSystem
        The pyarrow filesystem instance to wrap.
    """

    root_marker = "/"

    def __init__(self, fs, **kwargs):
        global PYARROW_VERSION
        # Look the version up without importing pyarrow, so constructing the
        # wrapper does not itself trigger a pyarrow import.
        PYARROW_VERSION = get_package_version_without_import("pyarrow")
        self.fs = fs
        super().__init__(**kwargs)

    @property
    def protocol(self):
        # Mirror the wrapped filesystem's type name (e.g. "hdfs").
        return self.fs.type_name

    @cached_property
    def fsid(self):
        return "hdfs_" + tokenize(self.fs.host, self.fs.port)

    @classmethod
    def _strip_protocol(cls, path):
        ops = infer_storage_options(path)
        path = ops["path"]
        if path.startswith("//"):
            # special case for "hdfs://path" (without the triple slash)
            path = path[1:]
        return path

    def ls(self, path, detail=False, **kwargs):
        """List entries directly under ``path``; details dicts if ``detail``."""
        path = self._strip_protocol(path)
        from pyarrow.fs import FileSelector

        entries = [
            self._make_entry(entry)
            for entry in self.fs.get_file_info(FileSelector(path))
        ]
        if detail:
            return entries
        else:
            return [entry["name"] for entry in entries]

    def info(self, path, **kwargs):
        """Return the fsspec details dict for a single path."""
        path = self._strip_protocol(path)
        [info] = self.fs.get_file_info([path])
        return self._make_entry(info)

    def exists(self, path):
        path = self._strip_protocol(path)
        try:
            self.info(path)
        except FileNotFoundError:
            return False
        else:
            return True

    def _make_entry(self, info):
        """Convert a pyarrow ``FileInfo`` into an fsspec details dict.

        Raises ``FileNotFoundError`` for ``FileType.NotFound`` entries.
        """
        from pyarrow.fs import FileType

        if info.type is FileType.Directory:
            kind = "directory"
        elif info.type is FileType.File:
            kind = "file"
        elif info.type is FileType.NotFound:
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), info.path)
        else:
            kind = "other"

        return {
            "name": info.path,
            "size": info.size,
            "type": kind,
            "mtime": info.mtime,
        }

    @wrap_exceptions
    def cp_file(self, path1, path2, **kwargs):
        """Copy a single file via a temporary name, then move it into place,
        so a partially-written destination is never visible."""
        path1 = self._strip_protocol(path1).rstrip("/")
        path2 = self._strip_protocol(path2).rstrip("/")

        with self._open(path1, "rb") as lstream:
            tmp_fname = f"{path2}.tmp.{secrets.token_hex(6)}"
            try:
                with self.open(tmp_fname, "wb") as rstream:
                    shutil.copyfileobj(lstream, rstream)
                self.fs.move(tmp_fname, path2)
            except BaseException:
                # Best-effort cleanup of the temporary file before re-raising.
                with suppress(FileNotFoundError):
                    self.fs.delete_file(tmp_fname)
                raise

    @wrap_exceptions
    def mv(self, path1, path2, **kwargs):
        path1 = self._strip_protocol(path1).rstrip("/")
        path2 = self._strip_protocol(path2).rstrip("/")
        self.fs.move(path1, path2)

    @wrap_exceptions
    def rm_file(self, path):
        path = self._strip_protocol(path)
        self.fs.delete_file(path)

    @wrap_exceptions
    def rm(self, path, recursive=False, maxdepth=None):
        path = self._strip_protocol(path).rstrip("/")
        if self.isdir(path):
            if recursive:
                self.fs.delete_dir(path)
            else:
                # Fixed message: the flag callers must pass is recursive=True
                # (previous text said "recursive=False").
                raise ValueError("Can't delete directories without recursive=True")
        else:
            self.fs.delete_file(path)

    @wrap_exceptions
    def _open(self, path, mode="rb", block_size=None, seekable=True, **kwargs):
        """Open ``path``, picking the matching pyarrow stream constructor."""
        if mode == "rb":
            if seekable:
                method = self.fs.open_input_file
            else:
                method = self.fs.open_input_stream
        elif mode == "wb":
            method = self.fs.open_output_stream
        elif mode == "ab":
            method = self.fs.open_append_stream
        else:
            raise ValueError(f"unsupported mode for Arrow filesystem: {mode!r}")

        _kwargs = {}
        if mode != "rb" or not seekable:
            if int(PYARROW_VERSION.split(".")[0]) >= 4:
                # disable compression auto-detection
                _kwargs["compression"] = None
        stream = method(path, **_kwargs)

        return ArrowFile(self, stream, path, mode, block_size, **kwargs)

    @wrap_exceptions
    def mkdir(self, path, create_parents=True, **kwargs):
        path = self._strip_protocol(path)
        if create_parents:
            self.makedirs(path, exist_ok=True)
        else:
            self.fs.create_dir(path, recursive=False)

    @wrap_exceptions
    def makedirs(self, path, exist_ok=False):
        # NOTE(review): pyarrow's create_dir(recursive=True) does not raise
        # when the directory already exists, so ``exist_ok`` is effectively
        # ignored here — confirm whether strict exist_ok=False is required.
        path = self._strip_protocol(path)
        self.fs.create_dir(path, recursive=True)

    @wrap_exceptions
    def rmdir(self, path):
        path = self._strip_protocol(path)
        self.fs.delete_dir(path)

    @wrap_exceptions
    def modified(self, path):
        path = self._strip_protocol(path)
        return self.fs.get_file_info(path).mtime

    def cat_file(self, path, start=None, end=None, **kwargs):
        # A plain (non-seekable) stream suffices for whole-file reads or
        # reads from offset 0; otherwise a seekable file is needed so the
        # range can be honoured.
        kwargs["seekable"] = start not in [None, 0]
        # Fixed: forward the requested byte range. Previously this passed
        # start=None, end=None, silently returning the whole file for
        # ranged requests despite computing ``seekable`` from ``start``.
        return super().cat_file(path, start=start, end=end, **kwargs)

    def get_file(self, rpath, lpath, **kwargs):
        # Sequential whole-file download: a plain input stream is cheaper
        # than a seekable file.
        kwargs["seekable"] = False
        super().get_file(rpath, lpath, **kwargs)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
@mirror_from(
    "stream",
    [
        "read",
        "seek",
        "tell",
        "write",
        "readable",
        "writable",
        "close",
        "size",
        "seekable",
    ],
)
class ArrowFile(io.IOBase):
    """File-like object returned by ``ArrowFSWrapper._open``.

    All I/O methods listed in the ``mirror_from`` decorator are delegated to
    the underlying pyarrow ``stream``; this class only stores bookkeeping
    attributes and provides context-manager support.
    """

    def __init__(self, fs, stream, path, mode, block_size=None, **kwargs):
        # Remote path this handle refers to and the mode it was opened with.
        self.path = path
        self.mode = mode

        # Owning filesystem wrapper and the underlying pyarrow stream.
        self.fs = fs
        self.stream = stream

        # Both spellings kept for compatibility with fsspec conventions.
        self.blocksize = self.block_size = block_size
        self.kwargs = kwargs

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Delegates to the mirrored ``close`` of the underlying stream.
        return self.close()
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class HadoopFileSystem(ArrowFSWrapper):
    """Expose ``pyarrow.fs.HadoopFileSystem`` through the fsspec interface."""

    protocol = "hdfs"

    def __init__(
        self,
        host="default",
        port=0,
        user=None,
        kerb_ticket=None,
        replication=3,
        extra_conf=None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        host: str
            Hostname, IP or "default" to try to read from Hadoop config
        port: int
            Port to connect on, or default from Hadoop config if 0
        user: str or None
            If given, connect as this username
        kerb_ticket: str or None
            If given, use this ticket for authentication
        replication: int
            set replication factor of file for write operations. default value is 3.
        extra_conf: None or dict
            Passed on to HadoopFileSystem
        """
        # Alias the import so it does not shadow this class's own name.
        from pyarrow.fs import HadoopFileSystem as PaHadoopFileSystem

        wrapped = PaHadoopFileSystem(
            host=host,
            port=port,
            user=user,
            kerb_ticket=kerb_ticket,
            replication=replication,
            extra_conf=extra_conf,
        )
        super().__init__(fs=wrapped, **kwargs)

    @staticmethod
    def _get_kwargs_from_urls(path):
        """Derive constructor kwargs (host/user/port/replication) from a URL."""
        ops = infer_storage_options(path)
        out = {}
        if ops.get("host"):
            out["host"] = ops["host"]
        if ops.get("username"):
            out["user"] = ops["username"]
        if ops.get("port"):
            out["port"] = ops["port"]
        if ops.get("url_query"):
            replication = parse_qs(ops["url_query"]).get("replication")
            if replication:
                out["replication"] = int(replication[0])
        return out
|
pythonProject/.venv/Lib/site-packages/fsspec/implementations/asyn_wrapper.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import functools
|
| 3 |
+
import inspect
|
| 4 |
+
|
| 5 |
+
import fsspec
|
| 6 |
+
from fsspec.asyn import AsyncFileSystem, running_async
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def async_wrapper(func, obj=None, semaphore=None):
    """
    Wraps a synchronous function to make it awaitable.

    Parameters
    ----------
    func : callable
        The synchronous function to wrap.
    obj : object, optional
        The instance to bind the function to, if applicable.
    semaphore : asyncio.Semaphore, optional
        A semaphore to limit concurrent calls.

    Returns
    -------
    coroutine
        An awaitable version of the function.
    """

    @functools.wraps(func)
    async def awaitable(*args, **kwargs):
        # Run the blocking call in the default thread pool; when a semaphore
        # is supplied, acquire it first to bound concurrency.
        if not semaphore:
            return await asyncio.to_thread(func, *args, **kwargs)
        async with semaphore:
            return await asyncio.to_thread(func, *args, **kwargs)

    return awaitable
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class AsyncFileSystemWrapper(AsyncFileSystem):
    """
    A wrapper class to convert a synchronous filesystem into an asynchronous one.

    This class takes an existing synchronous filesystem implementation and wraps all
    its methods to provide an asynchronous interface.

    Parameters
    ----------
    sync_fs : AbstractFileSystem
        The synchronous filesystem instance to wrap.
    """

    protocol = "asyncwrapper", "async_wrapper"
    # Each wrapped instance is tied to a specific sync filesystem, so
    # instances are not reused from the fsspec cache.
    cachable = False

    def __init__(
        self,
        fs=None,
        asynchronous=None,
        target_protocol=None,
        target_options=None,
        semaphore=None,
        max_concurrent_tasks=None,
        **kwargs,
    ):
        # Default ``asynchronous`` to whether an event loop is already running.
        if asynchronous is None:
            asynchronous = running_async()
        super().__init__(asynchronous=asynchronous, **kwargs)
        if fs is not None:
            self.sync_fs = fs
        else:
            # Build the sync filesystem lazily from a protocol name + options.
            self.sync_fs = fsspec.filesystem(target_protocol, **target_options)
        # Present the wrapped filesystem's protocol to callers.
        self.protocol = self.sync_fs.protocol
        # Optional asyncio.Semaphore bounding concurrent delegated calls.
        # NOTE(review): ``max_concurrent_tasks`` is accepted but never used
        # here — confirm whether it should create a semaphore when none given.
        self.semaphore = semaphore
        self._wrap_all_sync_methods()

    @property
    def fsid(self):
        return f"async_{self.sync_fs.fsid}"

    def _wrap_all_sync_methods(self):
        """
        Wrap all synchronous methods of the underlying filesystem with asynchronous versions.
        """
        # ``open`` is excluded: file handles need dedicated async treatment.
        excluded_methods = {"open"}
        for method_name in dir(self.sync_fs):
            if method_name.startswith("_") or method_name in excluded_methods:
                continue

            # getattr_static avoids triggering property getters while probing.
            attr = inspect.getattr_static(self.sync_fs, method_name)
            if isinstance(attr, property):
                continue

            method = getattr(self.sync_fs, method_name)
            if callable(method) and not inspect.iscoroutinefunction(method):
                async_method = async_wrapper(method, obj=self, semaphore=self.semaphore)
                # Stored under the leading-underscore name (e.g. ``_ls``),
                # the slot AsyncFileSystem resolves async implementations from.
                setattr(self, f"_{method_name}", async_method)

    @classmethod
    def wrap_class(cls, sync_fs_class):
        """
        Create a new class that can be used to instantiate an AsyncFileSystemWrapper
        with lazy instantiation of the underlying synchronous filesystem.

        Parameters
        ----------
        sync_fs_class : type
            The class of the synchronous filesystem to wrap.

        Returns
        -------
        type
            A new class that wraps the provided synchronous filesystem class.
        """

        class GeneratedAsyncFileSystemWrapper(cls):
            def __init__(self, *args, **kwargs):
                # Construct the sync filesystem first, then wrap it.
                sync_fs = sync_fs_class(*args, **kwargs)
                super().__init__(sync_fs)

        GeneratedAsyncFileSystemWrapper.__name__ = (
            f"Async{sync_fs_class.__name__}Wrapper"
        )
        return GeneratedAsyncFileSystemWrapper
|
pythonProject/.venv/Lib/site-packages/fsspec/implementations/cache_mapper.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import abc
|
| 4 |
+
import hashlib
|
| 5 |
+
|
| 6 |
+
from fsspec.implementations.local import make_path_posix
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AbstractCacheMapper(abc.ABC):
    """Abstract super-class for mappers from remote URLs to local cached
    basenames.
    """

    @abc.abstractmethod
    def __call__(self, path: str) -> str: ...

    def __eq__(self, other: object) -> bool:
        """Mappers are interchangeable iff they are instances of the same
        class; stateful subclasses must extend this comparison."""
        return isinstance(other, type(self))

    def __hash__(self) -> int:
        """Consistent with ``__eq__``: hashes on the class alone."""
        return hash(type(self))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class BasenameCacheMapper(AbstractCacheMapper):
    """Cache mapper that uses the basename of the remote URL and a fixed number
    of directory levels above this.

    The default is zero directory levels, meaning different paths with the same
    basename will have the same cached basename.
    """

    def __init__(self, directory_levels: int = 0):
        if directory_levels < 0:
            raise ValueError(
                "BasenameCacheMapper requires zero or positive directory_levels"
            )
        self.directory_levels = directory_levels

        # Separator for directories when encoded as strings.
        self._separator = "_@_"

    def __call__(self, path: str) -> str:
        posix_path = make_path_posix(path)
        head, *tail = posix_path.rsplit("/", self.directory_levels + 1)
        # ``tail`` is empty when the path contains no "/" at all, in which
        # case the whole path already is the basename.
        return self._separator.join(tail) if tail else head

    def __eq__(self, other: object) -> bool:
        if not super().__eq__(other):
            return False
        return self.directory_levels == other.directory_levels

    def __hash__(self) -> int:
        return super().__hash__() ^ hash(self.directory_levels)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class HashCacheMapper(AbstractCacheMapper):
    """Cache mapper that uses a hash of the remote URL."""

    def __call__(self, path: str) -> str:
        # SHA-256 hex digest keeps names filesystem-safe and collision-free.
        digest = hashlib.sha256(path.encode())
        return digest.hexdigest()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def create_cache_mapper(same_names: bool) -> AbstractCacheMapper:
    """Factory method to create cache mapper for backward compatibility with
    ``CachingFileSystem`` constructor using ``same_names`` kwarg.
    """
    return BasenameCacheMapper() if same_names else HashCacheMapper()
|
pythonProject/.venv/Lib/site-packages/fsspec/implementations/cache_metadata.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import pickle
|
| 5 |
+
import time
|
| 6 |
+
from typing import TYPE_CHECKING
|
| 7 |
+
|
| 8 |
+
from fsspec.utils import atomic_write
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
import ujson as json
|
| 12 |
+
except ImportError:
|
| 13 |
+
if not TYPE_CHECKING:
|
| 14 |
+
import json
|
| 15 |
+
|
| 16 |
+
if TYPE_CHECKING:
|
| 17 |
+
from collections.abc import Iterator
|
| 18 |
+
from typing import Any, Literal
|
| 19 |
+
|
| 20 |
+
from typing_extensions import TypeAlias
|
| 21 |
+
|
| 22 |
+
from .cached import CachingFileSystem
|
| 23 |
+
|
| 24 |
+
Detail: TypeAlias = dict[str, Any]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class CacheMetadata:
|
| 28 |
+
"""Cache metadata.
|
| 29 |
+
|
| 30 |
+
All reading and writing of cache metadata is performed by this class,
|
| 31 |
+
accessing the cached files and blocks is not.
|
| 32 |
+
|
| 33 |
+
Metadata is stored in a single file per storage directory in JSON format.
|
| 34 |
+
For backward compatibility, also reads metadata stored in pickle format
|
| 35 |
+
which is converted to JSON when next saved.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, storage: list[str]):
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
Parameters
|
| 42 |
+
----------
|
| 43 |
+
storage: list[str]
|
| 44 |
+
Directories containing cached files, must be at least one. Metadata
|
| 45 |
+
is stored in the last of these directories by convention.
|
| 46 |
+
"""
|
| 47 |
+
if not storage:
|
| 48 |
+
raise ValueError("CacheMetadata expects at least one storage location")
|
| 49 |
+
|
| 50 |
+
self._storage = storage
|
| 51 |
+
self.cached_files: list[Detail] = [{}]
|
| 52 |
+
|
| 53 |
+
# Private attribute to force saving of metadata in pickle format rather than
|
| 54 |
+
# JSON for use in tests to confirm can read both pickle and JSON formats.
|
| 55 |
+
self._force_save_pickle = False
|
| 56 |
+
|
| 57 |
+
def _load(self, fn: str) -> Detail:
|
| 58 |
+
"""Low-level function to load metadata from specific file"""
|
| 59 |
+
try:
|
| 60 |
+
with open(fn, "r") as f:
|
| 61 |
+
loaded = json.load(f)
|
| 62 |
+
except ValueError:
|
| 63 |
+
with open(fn, "rb") as f:
|
| 64 |
+
loaded = pickle.load(f)
|
| 65 |
+
for c in loaded.values():
|
| 66 |
+
if isinstance(c.get("blocks"), list):
|
| 67 |
+
c["blocks"] = set(c["blocks"])
|
| 68 |
+
return loaded
|
| 69 |
+
|
| 70 |
+
def _save(self, metadata_to_save: Detail, fn: str) -> None:
|
| 71 |
+
"""Low-level function to save metadata to specific file"""
|
| 72 |
+
if self._force_save_pickle:
|
| 73 |
+
with atomic_write(fn) as f:
|
| 74 |
+
pickle.dump(metadata_to_save, f)
|
| 75 |
+
else:
|
| 76 |
+
with atomic_write(fn, mode="w") as f:
|
| 77 |
+
json.dump(metadata_to_save, f)
|
| 78 |
+
|
| 79 |
+
def _scan_locations(
|
| 80 |
+
self, writable_only: bool = False
|
| 81 |
+
) -> Iterator[tuple[str, str, bool]]:
|
| 82 |
+
"""Yield locations (filenames) where metadata is stored, and whether
|
| 83 |
+
writable or not.
|
| 84 |
+
|
| 85 |
+
Parameters
|
| 86 |
+
----------
|
| 87 |
+
writable: bool
|
| 88 |
+
Set to True to only yield writable locations.
|
| 89 |
+
|
| 90 |
+
Returns
|
| 91 |
+
-------
|
| 92 |
+
Yields (str, str, bool)
|
| 93 |
+
"""
|
| 94 |
+
n = len(self._storage)
|
| 95 |
+
for i, storage in enumerate(self._storage):
|
| 96 |
+
writable = i == n - 1
|
| 97 |
+
if writable_only and not writable:
|
| 98 |
+
continue
|
| 99 |
+
yield os.path.join(storage, "cache"), storage, writable
|
| 100 |
+
|
| 101 |
+
def check_file(
|
| 102 |
+
self, path: str, cfs: CachingFileSystem | None
|
| 103 |
+
) -> Literal[False] | tuple[Detail, str]:
|
| 104 |
+
"""If path is in cache return its details, otherwise return ``False``.
|
| 105 |
+
|
| 106 |
+
If the optional CachingFileSystem is specified then it is used to
|
| 107 |
+
perform extra checks to reject possible matches, such as if they are
|
| 108 |
+
too old.
|
| 109 |
+
"""
|
| 110 |
+
for (fn, base, _), cache in zip(self._scan_locations(), self.cached_files):
|
| 111 |
+
if path not in cache:
|
| 112 |
+
continue
|
| 113 |
+
detail = cache[path].copy()
|
| 114 |
+
|
| 115 |
+
if cfs is not None:
|
| 116 |
+
if cfs.check_files and detail["uid"] != cfs.fs.ukey(path):
|
| 117 |
+
# Wrong file as determined by hash of file properties
|
| 118 |
+
continue
|
| 119 |
+
if cfs.expiry and time.time() - detail["time"] > cfs.expiry:
|
| 120 |
+
# Cached file has expired
|
| 121 |
+
continue
|
| 122 |
+
|
| 123 |
+
fn = os.path.join(base, detail["fn"])
|
| 124 |
+
if os.path.exists(fn):
|
| 125 |
+
return detail, fn
|
| 126 |
+
return False
|
| 127 |
+
|
| 128 |
+
def clear_expired(self, expiry_time: int) -> tuple[list[str], bool]:
|
| 129 |
+
"""Remove expired metadata from the cache.
|
| 130 |
+
|
| 131 |
+
Returns names of files corresponding to expired metadata and a boolean
|
| 132 |
+
flag indicating whether the writable cache is empty. Caller is
|
| 133 |
+
responsible for deleting the expired files.
|
| 134 |
+
"""
|
| 135 |
+
expired_files = []
|
| 136 |
+
for path, detail in self.cached_files[-1].copy().items():
|
| 137 |
+
if time.time() - detail["time"] > expiry_time:
|
| 138 |
+
fn = detail.get("fn", "")
|
| 139 |
+
if not fn:
|
| 140 |
+
raise RuntimeError(
|
| 141 |
+
f"Cache metadata does not contain 'fn' for {path}"
|
| 142 |
+
)
|
| 143 |
+
fn = os.path.join(self._storage[-1], fn)
|
| 144 |
+
expired_files.append(fn)
|
| 145 |
+
self.cached_files[-1].pop(path)
|
| 146 |
+
|
| 147 |
+
if self.cached_files[-1]:
|
| 148 |
+
cache_path = os.path.join(self._storage[-1], "cache")
|
| 149 |
+
self._save(self.cached_files[-1], cache_path)
|
| 150 |
+
|
| 151 |
+
writable_cache_empty = not self.cached_files[-1]
|
| 152 |
+
return expired_files, writable_cache_empty
|
| 153 |
+
|
| 154 |
+
def load(self) -> None:
|
| 155 |
+
"""Load all metadata from disk and store in ``self.cached_files``"""
|
| 156 |
+
cached_files = []
|
| 157 |
+
for fn, _, _ in self._scan_locations():
|
| 158 |
+
if os.path.exists(fn):
|
| 159 |
+
# TODO: consolidate blocks here
|
| 160 |
+
cached_files.append(self._load(fn))
|
| 161 |
+
else:
|
| 162 |
+
cached_files.append({})
|
| 163 |
+
self.cached_files = cached_files or [{}]
|
| 164 |
+
|
| 165 |
+
def on_close_cached_file(self, f: Any, path: str) -> None:
    """Update block metadata when a cached file is being closed.

    Closing the file object itself is the caller's responsibility.
    """
    # A file being closed must be writable, hence tracked in the last
    # (writable) cache level.
    detail = self.cached_files[-1][path]
    blocks = detail["blocks"]
    if blocks is not True:
        # Once every block has been fetched, mark the file as complete.
        if len(blocks) * f.blocksize >= f.size:
            detail["blocks"] = True
|
| 174 |
+
|
| 175 |
+
def pop_file(self, path: str) -> str | None:
    """Forget the metadata for *path*, returning the cached filename.

    Returns ``None`` when the path is not cached. The cached file itself
    is not deleted here; that is left to the caller.
    """
    found = self.check_file(path, None)
    if not found:
        return None
    _, fn = found
    if not fn.startswith(self._storage[-1]):
        # Only entries in the last (writable) location may be removed.
        raise PermissionError(
            "Can only delete cached file in last, writable cache location"
        )
    self.cached_files[-1].pop(path)
    self.save()
    return fn
|
| 194 |
+
|
| 195 |
+
def save(self) -> None:
    """Save metadata to disk.

    Metadata for the writable (last) location is merged with whatever is
    already stored on disk there, so that entries written by another
    process since the last load are not lost.
    """
    for (fn, _, writable), cache in zip(self._scan_locations(), self.cached_files):
        if not writable:
            # Read-only cache locations are never written back.
            continue

        if os.path.exists(fn):
            # Merge with the metadata currently on disk (possibly updated
            # by another process since we last loaded it).
            cached_files = self._load(fn)
            for k, c in cached_files.items():
                if k in cache:
                    if c["blocks"] is True or cache[k]["blocks"] is True:
                        # Either copy already knows the file is complete.
                        c["blocks"] = True
                    else:
                        # self.cached_files[*][*]["blocks"] must continue to
                        # point to the same set object so that updates
                        # performed by MMapCache are propagated back to
                        # self.cached_files.
                        blocks = cache[k]["blocks"]
                        blocks.update(c["blocks"])
                        c["blocks"] = blocks
                    # Keep the most recent timestamp, and our view of uid.
                    c["time"] = max(c["time"], cache[k]["time"])
                    c["uid"] = cache[k]["uid"]

            # Files can be added to cache after it was written once
            for k, c in cache.items():
                if k not in cached_files:
                    cached_files[k] = c
        else:
            cached_files = cache
        # Serialize a shallow copy: sets of block numbers become lists on
        # disk, while the in-memory entries keep their live set objects.
        cache = {k: v.copy() for k, v in cached_files.items()}
        for c in cache.values():
            if isinstance(c["blocks"], set):
                c["blocks"] = list(c["blocks"])
        self._save(cache, fn)
        self.cached_files[-1] = cached_files
|
| 230 |
+
|
| 231 |
+
def update_file(self, path: str, detail: Detail) -> None:
    """Record *detail* for *path* in the in-memory writable cache (no save)."""
    writable_cache = self.cached_files[-1]
    writable_cache[path] = detail
|
pythonProject/.venv/Lib/site-packages/fsspec/implementations/cached.py
ADDED
|
@@ -0,0 +1,998 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import inspect
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import tempfile
|
| 7 |
+
import time
|
| 8 |
+
import weakref
|
| 9 |
+
from shutil import rmtree
|
| 10 |
+
from typing import TYPE_CHECKING, Any, Callable, ClassVar
|
| 11 |
+
|
| 12 |
+
from fsspec import AbstractFileSystem, filesystem
|
| 13 |
+
from fsspec.callbacks import DEFAULT_CALLBACK
|
| 14 |
+
from fsspec.compression import compr
|
| 15 |
+
from fsspec.core import BaseCache, MMapCache
|
| 16 |
+
from fsspec.exceptions import BlocksizeMismatchError
|
| 17 |
+
from fsspec.implementations.cache_mapper import create_cache_mapper
|
| 18 |
+
from fsspec.implementations.cache_metadata import CacheMetadata
|
| 19 |
+
from fsspec.implementations.local import LocalFileSystem
|
| 20 |
+
from fsspec.spec import AbstractBufferedFile
|
| 21 |
+
from fsspec.transaction import Transaction
|
| 22 |
+
from fsspec.utils import infer_compression
|
| 23 |
+
|
| 24 |
+
if TYPE_CHECKING:
|
| 25 |
+
from fsspec.implementations.cache_mapper import AbstractCacheMapper
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger("fsspec.cached")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class WriteCachedTransaction(Transaction):
    """Transaction that uploads locally buffered files on commit."""

    def complete(self, commit=True):
        """Finish the transaction, uploading the buffered files if *commit*."""
        remote_paths = [f.path for f in self.files]
        local_paths = [f.fn for f in self.files]
        if commit:
            self.fs.put(local_paths, remote_paths)
        self.files.clear()
        self.fs._intrans = False
        self.fs._transaction = None
        self.fs = None  # break cycle
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class CachingFileSystem(AbstractFileSystem):
    """Locally caching filesystem, layer over any other FS

    This class implements chunk-wise local storage of remote files, for quick
    access after the initial download. The files are stored in a given
    directory with hashes of URLs for the filenames. If no directory is given,
    a temporary one is used, which should be cleaned up by the OS after the
    process ends. The files themselves are sparse (as implemented in
    :class:`~fsspec.caching.MMapCache`), so only the data which is accessed
    takes up space.

    Restrictions:

    - the block-size must be the same for each access of a given file, unless
      all blocks of the file have already been read
    - caching can only be applied to file-systems which produce files
      derived from fsspec.spec.AbstractBufferedFile ; LocalFileSystem is also
      allowed, for testing
    """

    protocol: ClassVar[str | tuple[str, ...]] = ("blockcache", "cached")

    def __init__(
        self,
        target_protocol=None,
        cache_storage="TMP",
        cache_check=10,
        check_files=False,
        expiry_time=604800,
        target_options=None,
        fs=None,
        same_names: bool | None = None,
        compression=None,
        cache_mapper: AbstractCacheMapper | None = None,
        **kwargs,
    ):
        """

        Parameters
        ----------
        target_protocol: str (optional)
            Target filesystem protocol. Provide either this or ``fs``.
        cache_storage: str or list(str)
            Location to store files. If "TMP", this is a temporary directory,
            and will be cleaned up by the OS when this process ends (or later).
            If a list, each location will be tried in the order given, but
            only the last will be considered writable.
        cache_check: int
            Number of seconds between reload of cache metadata
        check_files: bool
            Whether to explicitly see if the UID of the remote file matches
            the stored one before using. Warning: some file systems such as
            HTTP cannot reliably give a unique hash of the contents of some
            path, so be sure to set this option to False.
        expiry_time: int
            The time in seconds after which a local copy is considered useless.
            Set to falsy to prevent expiry. The default is equivalent to one
            week.
        target_options: dict or None
            Passed to the instantiation of the FS, if fs is None.
        fs: filesystem instance
            The target filesystem to run against. Provide this or ``protocol``.
        same_names: bool (optional)
            By default, target URLs are hashed using a ``HashCacheMapper`` so
            that files from different backends with the same basename do not
            conflict. If this argument is ``true``, a ``BasenameCacheMapper``
            is used instead. Other cache mapper options are available by using
            the ``cache_mapper`` keyword argument. Only one of this and
            ``cache_mapper`` should be specified.
        compression: str (optional)
            To decompress on download. Can be 'infer' (guess from the URL name),
            one of the entries in ``fsspec.compression.compr``, or None for no
            decompression.
        cache_mapper: AbstractCacheMapper (optional)
            The object use to map from original filenames to cached filenames.
            Only one of this and ``same_names`` should be specified.
        """
        super().__init__(**kwargs)
        if fs is None and target_protocol is None:
            raise ValueError(
                "Please provide filesystem instance(fs) or target_protocol"
            )
        if not (fs is None) ^ (target_protocol is None):
            raise ValueError(
                "Both filesystems (fs) and target_protocol may not be both given."
            )
        if cache_storage == "TMP":
            tempdir = tempfile.mkdtemp()
            storage = [tempdir]
            # Clean the temp dir up when this instance is garbage-collected.
            weakref.finalize(self, self._remove_tempdir, tempdir)
        else:
            if isinstance(cache_storage, str):
                storage = [cache_storage]
            else:
                storage = cache_storage
        os.makedirs(storage[-1], exist_ok=True)
        self.storage = storage
        self.kwargs = target_options or {}
        self.cache_check = cache_check
        self.check_files = check_files
        self.expiry = expiry_time
        self.compression = compression

        # Size of cache in bytes. If None then the size is unknown and will be
        # recalculated the next time cache_size() is called. On writes to the
        # cache this is reset to None.
        self._cache_size = None

        if same_names is not None and cache_mapper is not None:
            raise ValueError(
                "Cannot specify both same_names and cache_mapper in "
                "CachingFileSystem.__init__"
            )
        if cache_mapper is not None:
            self._mapper = cache_mapper
        else:
            self._mapper = create_cache_mapper(
                same_names if same_names is not None else False
            )

        self.target_protocol = (
            target_protocol
            if isinstance(target_protocol, str)
            else (fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0])
        )
        self._metadata = CacheMetadata(self.storage)
        self.load_cache()
        self.fs = fs if fs is not None else filesystem(target_protocol, **self.kwargs)

        def _strip_protocol(path):
            # acts as a method, since each instance has a difference target
            return self.fs._strip_protocol(type(self)._strip_protocol(path))

        self._strip_protocol: Callable = _strip_protocol

    @staticmethod
    def _remove_tempdir(tempdir):
        # Best-effort cleanup of the temporary cache directory.
        try:
            rmtree(tempdir)
        except Exception:
            pass

    def _mkcache(self):
        # Ensure the writable cache directory exists.
        os.makedirs(self.storage[-1], exist_ok=True)

    def cache_size(self):
        """Return size of cache in bytes.

        If more than one cache directory is in use, only the size of the last
        one (the writable cache directory) is returned.
        """
        if self._cache_size is None:
            cache_dir = self.storage[-1]
            self._cache_size = filesystem("file").du(cache_dir, withdirs=True)
        return self._cache_size

    def load_cache(self):
        """Read set of stored blocks from file"""
        self._metadata.load()
        self._mkcache()
        self.last_cache = time.time()

    def save_cache(self):
        """Save set of stored blocks from file"""
        self._mkcache()
        self._metadata.save()
        self.last_cache = time.time()
        self._cache_size = None

    def _check_cache(self):
        """Reload caches if time elapsed or any disappeared"""
        self._mkcache()
        if not self.cache_check:
            # explicitly told not to bother checking
            return
        timecond = time.time() - self.last_cache > self.cache_check
        existcond = all(os.path.exists(storage) for storage in self.storage)
        if timecond or not existcond:
            self.load_cache()

    def _check_file(self, path):
        """Is path in cache and still valid"""
        path = self._strip_protocol(path)
        self._check_cache()
        return self._metadata.check_file(path, self)

    def clear_cache(self):
        """Remove all files and metadata from the cache

        In the case of multiple cache locations, this clears only the last one,
        which is assumed to be the read/write one.
        """
        rmtree(self.storage[-1])
        self.load_cache()
        self._cache_size = None

    def clear_expired_cache(self, expiry_time=None):
        """Remove all expired files and metadata from the cache

        In the case of multiple cache locations, this clears only the last one,
        which is assumed to be the read/write one.

        Parameters
        ----------
        expiry_time: int
            The time in seconds after which a local copy is considered useless.
            If not defined the default is equivalent to the attribute from the
            file caching instantiation.
        """

        if not expiry_time:
            expiry_time = self.expiry

        self._check_cache()

        expired_files, writable_cache_empty = self._metadata.clear_expired(expiry_time)
        for fn in expired_files:
            if os.path.exists(fn):
                os.remove(fn)

        if writable_cache_empty:
            # Nothing left: reset the writable location entirely.
            rmtree(self.storage[-1])
            self.load_cache()

        self._cache_size = None

    def pop_from_cache(self, path):
        """Remove cached version of given file

        Deletes local copy of the given (remote) path. If it is found in a cache
        location which is not the last, it is assumed to be read-only, and
        raises PermissionError
        """
        path = self._strip_protocol(path)
        fn = self._metadata.pop_file(path)
        if fn is not None:
            os.remove(fn)
        self._cache_size = None

    def _open(
        self,
        path,
        mode="rb",
        block_size=None,
        autocommit=True,
        cache_options=None,
        **kwargs,
    ):
        """Wrap the target _open

        If the whole file exists in the cache, just open it locally and
        return that.

        Otherwise, open the file on the target FS, and make it have a mmap
        cache pointing to the location which we determine, in our cache.
        The ``blocks`` instance is shared, so as the mmap cache instance
        updates, so does the entry in our ``cached_files`` attribute.
        We monkey-patch this file, so that when it closes, we call
        ``close_and_update`` to save the state of the blocks.
        """
        path = self._strip_protocol(path)

        path = self.fs._strip_protocol(path)
        if "r" not in mode:
            # Write modes are passed straight through to the target FS.
            return self.fs._open(
                path,
                mode=mode,
                block_size=block_size,
                autocommit=autocommit,
                cache_options=cache_options,
                **kwargs,
            )
        detail = self._check_file(path)
        if detail:
            # file is in cache
            detail, fn = detail
            hash, blocks = detail["fn"], detail["blocks"]
            if blocks is True:
                # stored file is complete
                logger.debug("Opening local copy of %s", path)
                return open(fn, mode)
            # TODO: action where partial file exists in read-only cache
            logger.debug("Opening partially cached copy of %s", path)
        else:
            hash = self._mapper(path)
            fn = os.path.join(self.storage[-1], hash)
            blocks = set()
            detail = {
                "original": path,
                "fn": hash,
                "blocks": blocks,
                "time": time.time(),
                "uid": self.fs.ukey(path),
            }
            self._metadata.update_file(path, detail)
            logger.debug("Creating local sparse file for %s", path)

        # explicitly submitting the size to the open call will avoid extra
        # operations when opening. This is particularly relevant
        # for any file that is read over a network, e.g. S3.
        size = detail.get("size")

        # call target filesystems open
        self._mkcache()
        f = self.fs._open(
            path,
            mode=mode,
            block_size=block_size,
            autocommit=autocommit,
            cache_options=cache_options,
            cache_type="none",
            size=size,
            **kwargs,
        )

        # set size if not already set
        if size is None:
            detail["size"] = f.size
            self._metadata.update_file(path, detail)

        if self.compression:
            comp = (
                infer_compression(path)
                if self.compression == "infer"
                else self.compression
            )
            f = compr[comp](f, mode="rb")
        if "blocksize" in detail:
            if detail["blocksize"] != f.blocksize:
                raise BlocksizeMismatchError(
                    f"Cached file must be reopened with same block"
                    f" size as original (old: {detail['blocksize']},"
                    f" new {f.blocksize})"
                )
        else:
            detail["blocksize"] = f.blocksize

        def _fetch_ranges(ranges):
            # Batch fetch of multiple byte ranges from the target FS.
            return self.fs.cat_ranges(
                [path] * len(ranges),
                [r[0] for r in ranges],
                [r[1] for r in ranges],
                **kwargs,
            )

        multi_fetcher = None if self.compression else _fetch_ranges
        # ``blocks`` is shared with the metadata entry, so MMapCache updates
        # propagate into our cached_files bookkeeping.
        f.cache = MMapCache(
            f.blocksize, f._fetch_range, f.size, fn, blocks, multi_fetcher=multi_fetcher
        )
        close = f.close
        f.close = lambda: self.close_and_update(f, close)
        self.save_cache()
        return f

    def _parent(self, path):
        # Delegate to the target filesystem's notion of parent paths.
        return self.fs._parent(path)

    def hash_name(self, path: str, *args: Any) -> str:
        # Kept for backward compatibility with downstream libraries.
        # Ignores extra arguments, previously same_name boolean.
        return self._mapper(path)

    def close_and_update(self, f, close):
        """Called when a file is closing, so store the set of blocks"""
        if f.closed:
            return
        path = self._strip_protocol(f.path)
        self._metadata.on_close_cached_file(f, path)
        try:
            logger.debug("going to save")
            self.save_cache()
            logger.debug("saved")
        except OSError:
            logger.debug("Cache saving failed while closing file")
        except NameError:
            logger.debug("Cache save failed due to interpreter shutdown")
        close()
        f.closed = True

    def ls(self, path, detail=True):
        # Listing always goes to the target filesystem; it is not cached.
        return self.fs.ls(path, detail)

    def __getattribute__(self, item):
        # Dispatch: methods defined by this class run here; everything else
        # is forwarded to the wrapped target filesystem.
        if item in {
            "load_cache",
            "_open",
            "save_cache",
            "close_and_update",
            "__init__",
            "__getattribute__",
            "__reduce__",
            "_make_local_details",
            "open",
            "cat",
            "cat_file",
            "_cat_file",
            "cat_ranges",
            "_cat_ranges",
            "get",
            "read_block",
            "tail",
            "head",
            "info",
            "ls",
            "exists",
            "isfile",
            "isdir",
            "_check_file",
            "_check_cache",
            "_mkcache",
            "clear_cache",
            "clear_expired_cache",
            "pop_from_cache",
            "local_file",
            "_paths_from_path",
            "get_mapper",
            "open_many",
            "commit_many",
            "hash_name",
            "__hash__",
            "__eq__",
            "to_json",
            "to_dict",
            "cache_size",
            "pipe_file",
            "pipe",
            "start_transaction",
            "end_transaction",
        }:
            # all the methods defined in this class. Note `open` here, since
            # it calls `_open`, but is actually in superclass
            return lambda *args, **kw: getattr(type(self), item).__get__(self)(
                *args, **kw
            )
        if item in ["__reduce_ex__"]:
            raise AttributeError
        if item in ["transaction"]:
            # property
            return type(self).transaction.__get__(self)
        if item in {"_cache", "transaction_type", "protocol"}:
            # class attributes
            return getattr(type(self), item)
        if item == "__class__":
            return type(self)
        d = object.__getattribute__(self, "__dict__")
        fs = d.get("fs", None)  # fs is not immediately defined
        if item in d:
            return d[item]
        elif fs is not None:
            if item in fs.__dict__:
                # attribute of instance
                return fs.__dict__[item]
            # attributed belonging to the target filesystem
            cls = type(fs)
            m = getattr(cls, item)
            if (inspect.isfunction(m) or inspect.isdatadescriptor(m)) and (
                not hasattr(m, "__self__") or m.__self__ is None
            ):
                # instance method
                return m.__get__(fs, cls)
            return m  # class method or attribute
        else:
            # attributes of the superclass, while target is being set up
            return super().__getattribute__(item)

    def __eq__(self, other):
        """Test for equality."""
        if self is other:
            return True
        if not isinstance(other, type(self)):
            return False
        return (
            self.storage == other.storage
            and self.kwargs == other.kwargs
            and self.cache_check == other.cache_check
            and self.check_files == other.check_files
            and self.expiry == other.expiry
            and self.compression == other.compression
            and self._mapper == other._mapper
            and self.target_protocol == other.target_protocol
        )

    def __hash__(self):
        """Calculate hash."""
        return (
            hash(tuple(self.storage))
            ^ hash(str(self.kwargs))
            ^ hash(self.cache_check)
            ^ hash(self.check_files)
            ^ hash(self.expiry)
            ^ hash(self.compression)
            ^ hash(self._mapper)
            ^ hash(self.target_protocol)
        )
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
class WholeFileCacheFileSystem(CachingFileSystem):
|
| 539 |
+
"""Caches whole remote files on first access
|
| 540 |
+
|
| 541 |
+
This class is intended as a layer over any other file system, and
|
| 542 |
+
will make a local copy of each file accessed, so that all subsequent
|
| 543 |
+
reads are local. This is similar to ``CachingFileSystem``, but without
|
| 544 |
+
the block-wise functionality and so can work even when sparse files
|
| 545 |
+
are not allowed. See its docstring for definition of the init
|
| 546 |
+
arguments.
|
| 547 |
+
|
| 548 |
+
The class still needs access to the remote store for listing files,
|
| 549 |
+
and may refresh cached files.
|
| 550 |
+
"""
|
| 551 |
+
|
| 552 |
+
protocol = "filecache"
|
| 553 |
+
local_file = True
|
| 554 |
+
|
| 555 |
+
def open_many(self, open_files, **kwargs):
    """Open several files at once, batch-downloading any that are missing.

    For read modes, every file not already cached (or out of date) is
    fetched with a single ``fs.get`` call before local handles are
    returned. For write modes, local temporary files are returned that
    upload on commit.
    """
    paths = [of.path for of in open_files]
    if "r" in open_files.mode:
        self._mkcache()
    else:
        # Write mode: hand back temp files that upload on commit/close.
        return [
            LocalTempFile(
                self.fs,
                path,
                mode=open_files.mode,
                fn=os.path.join(self.storage[-1], self._mapper(path)),
                **kwargs,
            )
            for path in paths
        ]

    if self.compression:
        raise NotImplementedError
    details = [self._check_file(sp) for sp in paths]
    downpath = [p for p, d in zip(paths, details) if not d]
    downfn0 = [
        os.path.join(self.storage[-1], self._mapper(p))
        for p, d in zip(paths, details)
    ]  # keep these path names for opening later
    downfn = [fn for fn, d in zip(downfn0, details) if not d]
    if downpath:
        # skip if all files are already cached and up to date
        self.fs.get(downpath, downfn)

        # update metadata - only happens when downloads are successful
        newdetail = [
            {
                "original": path,
                "fn": self._mapper(path),
                "blocks": True,
                "time": time.time(),
                "uid": self.fs.ukey(path),
            }
            for path in downpath
        ]
        for path, detail in zip(downpath, newdetail):
            self._metadata.update_file(path, detail)
        self.save_cache()

    def firstpart(fn):
        # helper to adapt both whole-file and simple-cache
        return fn[1] if isinstance(fn, tuple) else fn

    # Cached entries open via their stored filename; freshly downloaded
    # ones via the name computed above.
    return [
        open(firstpart(fn0) if fn0 else fn1, mode=open_files.mode)
        for fn0, fn1 in zip(details, downfn0)
    ]
|
| 607 |
+
|
| 608 |
+
def commit_many(self, open_files):
|
| 609 |
+
self.fs.put([f.fn for f in open_files], [f.path for f in open_files])
|
| 610 |
+
[f.close() for f in open_files]
|
| 611 |
+
for f in open_files:
|
| 612 |
+
# in case autocommit is off, and so close did not already delete
|
| 613 |
+
try:
|
| 614 |
+
os.remove(f.name)
|
| 615 |
+
except FileNotFoundError:
|
| 616 |
+
pass
|
| 617 |
+
self._cache_size = None
|
| 618 |
+
|
| 619 |
+
def _make_local_details(self, path):
|
| 620 |
+
hash = self._mapper(path)
|
| 621 |
+
fn = os.path.join(self.storage[-1], hash)
|
| 622 |
+
detail = {
|
| 623 |
+
"original": path,
|
| 624 |
+
"fn": hash,
|
| 625 |
+
"blocks": True,
|
| 626 |
+
"time": time.time(),
|
| 627 |
+
"uid": self.fs.ukey(path),
|
| 628 |
+
}
|
| 629 |
+
self._metadata.update_file(path, detail)
|
| 630 |
+
logger.debug("Copying %s to local cache", path)
|
| 631 |
+
return fn
|
| 632 |
+
|
| 633 |
+
    def cat(
        self,
        path,
        recursive=False,
        on_error="raise",
        callback=DEFAULT_CALLBACK,
        **kwargs,
    ):
        """Fetch (potentially multiple) paths' contents via the local cache.

        Missing files are bulk-downloaded first; every result is then read
        back from the local copy.

        Parameters
        ----------
        on_error: "raise" | "return" | other
            "raise" propagates the first failure; "return" stores the
            exception as that path's value; any other value skips the path.
        """
        paths = self.expand_path(
            path, recursive=recursive, maxdepth=kwargs.get("maxdepth")
        )
        getpaths = []
        storepaths = []
        fns = []
        out = {}
        # iterate a copy because failing paths are removed from `paths`
        for p in paths.copy():
            try:
                detail = self._check_file(p)
                if not detail:
                    # cache miss: register metadata and queue for download
                    fn = self._make_local_details(p)
                    getpaths.append(p)
                    storepaths.append(fn)
                else:
                    # detail may be (metadata, fn) or just fn
                    detail, fn = detail if isinstance(detail, tuple) else (None, detail)
                fns.append(fn)
            except Exception as e:
                if on_error == "raise":
                    raise
                if on_error == "return":
                    out[p] = e
                paths.remove(p)

        if getpaths:
            self.fs.get(getpaths, storepaths)
            self.save_cache()

        callback.set_size(len(paths))
        for p, fn in zip(paths, fns):
            with open(fn, "rb") as f:
                out[p] = f.read()
            callback.relative_update(1)
        # single non-recursive string input: unwrap to bare bytes
        if isinstance(path, str) and len(paths) == 1 and recursive is False:
            out = out[paths[0]]
        return out
|
| 677 |
+
|
| 678 |
+
    def _open(self, path, mode="rb", **kwargs):
        """Open *path*, whole-file-caching reads and buffering writes locally.

        Reads are served from the local copy (downloading it first if
        needed, with optional decompression); writes return a
        ``LocalTempFile`` that uploads on commit.
        """
        path = self._strip_protocol(path)
        if "r" not in mode:
            hash = self._mapper(path)
            fn = os.path.join(self.storage[-1], hash)
            user_specified_kwargs = {
                k: v
                for k, v in kwargs.items()
                # those kwargs were added by open(), we don't want them
                if k not in ["autocommit", "block_size", "cache_options"]
            }
            return LocalTempFile(self, path, mode=mode, fn=fn, **user_specified_kwargs)
        detail = self._check_file(path)
        if detail:
            detail, fn = detail
            _, blocks = detail["fn"], detail["blocks"]
            # whole-file cache stores blocks=True; any other value means the
            # entry was made by a block-wise cache and can't be served here
            if blocks is True:
                logger.debug("Opening local copy of %s", path)

                # In order to support downstream filesystems to be able to
                # infer the compression from the original filename, like
                # the `TarFileSystem`, let's extend the `io.BufferedReader`
                # fileobject protocol by adding a dedicated attribute
                # `original`.
                f = open(fn, mode)
                f.original = detail.get("original")
                return f
            else:
                raise ValueError(
                    f"Attempt to open partially cached file {path}"
                    f" as a wholly cached file"
                )
        else:
            # cache miss: register metadata, then download below
            fn = self._make_local_details(path)
        kwargs["mode"] = mode

        # call target filesystems open
        self._mkcache()
        if self.compression:
            with self.fs._open(path, **kwargs) as f, open(fn, "wb") as f2:
                if isinstance(f, AbstractBufferedFile):
                    # want no type of caching if just downloading whole thing
                    f.cache = BaseCache(0, f.cache.fetcher, f.size)
                comp = (
                    infer_compression(path)
                    if self.compression == "infer"
                    else self.compression
                )
                f = compr[comp](f, mode="rb")
                data = True
                while data:
                    block = getattr(f, "blocksize", 5 * 2**20)
                    data = f.read(block)
                    f2.write(data)
        else:
            self.fs.get_file(path, fn)
        self.save_cache()
        # recurse: the file is now cached, so this takes the fast path above
        return self._open(path, mode)
|
| 736 |
+
|
| 737 |
+
|
| 738 |
+
class SimpleCacheFileSystem(WholeFileCacheFileSystem):
    """Caches whole remote files on first access

    This class is intended as a layer over any other file system, and
    will make a local copy of each file accessed, so that all subsequent
    reads are local. This implementation only copies whole files, and
    does not keep any metadata about the download time or file details.
    It is therefore safer to use in multi-threaded/concurrent situations.

    This is the only of the caching filesystems that supports write: you will
    be given a real local open file, and upon close and commit, it will be
    uploaded to the target filesystem; the writability or the target URL is
    not checked until that time.

    """

    protocol = "simplecache"
    local_file = True
    transaction_type = WriteCachedTransaction

    def __init__(self, **kwargs):
        """Same parameters as the parent class, but expiry/consistency
        checking options are forced off (this cache keeps no metadata)."""
        kw = kwargs.copy()
        for key in ["cache_check", "expiry_time", "check_files"]:
            kw[key] = False
        super().__init__(**kw)
        for storage in self.storage:
            if not os.path.exists(storage):
                os.makedirs(storage, exist_ok=True)

    def _check_file(self, path):
        """Return the local cache filename for *path*, or None if absent.

        Purely existence-based: no staleness/uid checking is done.
        """
        self._check_cache()
        sha = self._mapper(path)
        for storage in self.storage:
            fn = os.path.join(storage, sha)
            if os.path.exists(fn):
                return fn

    def save_cache(self):
        # no metadata is kept, so persisting it is a no-op
        pass

    def load_cache(self):
        # no metadata is kept, so loading it is a no-op
        pass

    def pipe_file(self, path, value=None, **kwargs):
        # within a transaction, write through our own open() so the data is
        # staged locally and uploaded on commit
        if self._intrans:
            with self.open(path, "wb") as f:
                f.write(value)
        else:
            super().pipe_file(path, value)

    def ls(self, path, detail=True, **kwargs):
        """List *path*, merging in files staged in the current transaction."""
        path = self._strip_protocol(path)
        details = []
        try:
            details = self.fs.ls(
                path, detail=True, **kwargs
            ).copy()  # don't edit original!
        except FileNotFoundError as e:
            # defer: the path may exist purely as in-transaction content
            ex = e
        else:
            ex = None
        if self._intrans:
            path1 = path.rstrip("/") + "/"
            for f in self.transaction.files:
                if f.path == path:
                    details.append(
                        {"name": path, "size": f.size or f.tell(), "type": "file"}
                    )
                elif f.path.startswith(path1):
                    if f.path.count("/") == path1.count("/"):
                        # direct child file of the listed directory
                        details.append(
                            {"name": f.path, "size": f.size or f.tell(), "type": "file"}
                        )
                    else:
                        # deeper descendant: surface its top-level directory
                        dname = "/".join(f.path.split("/")[: path1.count("/") + 1])
                        details.append({"name": dname, "size": 0, "type": "directory"})
        if ex is not None and not details:
            raise ex
        if detail:
            return details
        return sorted(_["name"] for _ in details)

    def info(self, path, **kwargs):
        """Info for *path*, preferring in-transaction staged files."""
        path = self._strip_protocol(path)
        if self._intrans:
            f = [_ for _ in self.transaction.files if _.path == path]
            if f:
                # closed staged file: size on disk; open one: current offset
                size = os.path.getsize(f[0].fn) if f[0].closed else f[0].tell()
                return {"name": path, "size": size, "type": "file"}
            f = any(_.path.startswith(path + "/") for _ in self.transaction.files)
            if f:
                return {"name": path, "size": 0, "type": "directory"}
        return self.fs.info(path, **kwargs)

    def pipe(self, path, value=None, **kwargs):
        if isinstance(path, str):
            self.pipe_file(self._strip_protocol(path), value, **kwargs)
        elif isinstance(path, dict):
            for k, v in path.items():
                self.pipe_file(self._strip_protocol(k), v, **kwargs)
        else:
            raise ValueError("path must be str or dict")

    async def _cat_file(self, path, start=None, end=None, **kwargs):
        """Async read of (a range of) *path*, via the local cache."""
        logger.debug("async cat_file %s", path)
        path = self._strip_protocol(path)
        sha = self._mapper(path)
        fn = self._check_file(path)

        if not fn:
            fn = os.path.join(self.storage[-1], sha)
            await self.fs._get_file(path, fn, **kwargs)

        with open(fn, "rb") as f:  # noqa ASYNC230
            if start:
                f.seek(start)
            size = -1 if end is None else end - f.tell()
            return f.read(size)

    async def _cat_ranges(
        self, paths, starts, ends, max_gap=None, on_error="return", **kwargs
    ):
        """Async ranged reads; missing files are downloaded once, then all
        ranges are served from local copies."""
        logger.debug("async cat ranges %s", paths)
        lpaths = []
        rset = set()
        download = []
        rpaths = []
        for p in paths:
            fn = self._check_file(p)
            # rset de-duplicates downloads when the same path appears twice
            if fn is None and p not in rset:
                sha = self._mapper(p)
                fn = os.path.join(self.storage[-1], sha)
                download.append(fn)
                rset.add(p)
                rpaths.append(p)
            lpaths.append(fn)
        if download:
            await self.fs._get(rpaths, download, on_error=on_error)

        return LocalFileSystem().cat_ranges(
            lpaths, starts, ends, max_gap=max_gap, on_error=on_error, **kwargs
        )

    def cat_ranges(
        self, paths, starts, ends, max_gap=None, on_error="return", **kwargs
    ):
        # NOTE(review): `_check_file` returns a filename or None, never
        # False, so the `l is False` filters below never select anything —
        # uncached paths appear to never be downloaded here. Compare the
        # `fn is None` test used in `_cat_ranges` above; verify upstream.
        logger.debug("cat ranges %s", paths)
        lpaths = [self._check_file(p) for p in paths]
        rpaths = [p for l, p in zip(lpaths, paths) if l is False]
        lpaths = [l for l, p in zip(lpaths, paths) if l is False]
        self.fs.get(rpaths, lpaths)
        paths = [self._check_file(p) for p in paths]
        return LocalFileSystem().cat_ranges(
            paths, starts, ends, max_gap=max_gap, on_error=on_error, **kwargs
        )

    def _open(self, path, mode="rb", **kwargs):
        """Open *path*: reads go through the local cache (downloading and
        optionally decompressing first); writes return a LocalTempFile."""
        path = self._strip_protocol(path)
        sha = self._mapper(path)

        if "r" not in mode:
            fn = os.path.join(self.storage[-1], sha)
            user_specified_kwargs = {
                k: v
                for k, v in kwargs.items()
                if k not in ["autocommit", "block_size", "cache_options"]
            }  # those were added by open()
            return LocalTempFile(
                self,
                path,
                mode=mode,
                autocommit=not self._intrans,
                fn=fn,
                **user_specified_kwargs,
            )
        fn = self._check_file(path)
        if fn:
            return open(fn, mode)

        fn = os.path.join(self.storage[-1], sha)
        logger.debug("Copying %s to local cache", path)
        kwargs["mode"] = mode

        self._mkcache()
        self._cache_size = None
        if self.compression:
            with self.fs._open(path, **kwargs) as f, open(fn, "wb") as f2:
                if isinstance(f, AbstractBufferedFile):
                    # want no type of caching if just downloading whole thing
                    f.cache = BaseCache(0, f.cache.fetcher, f.size)
                comp = (
                    infer_compression(path)
                    if self.compression == "infer"
                    else self.compression
                )
                f = compr[comp](f, mode="rb")
                data = True
                while data:
                    block = getattr(f, "blocksize", 5 * 2**20)
                    data = f.read(block)
                    f2.write(data)
        else:
            self.fs.get_file(path, fn)
        # recurse: the file is now cached, so this takes the fast path above
        return self._open(path, mode)
|
| 942 |
+
|
| 943 |
+
|
| 944 |
+
class LocalTempFile:
    """A scratch file on local disk whose contents are uploaded to the
    remote filesystem when committed.

    Any attribute not defined here (``write``, ``read``, ``tell``, ...) is
    delegated to the underlying open file handle.
    """

    def __init__(self, fs, path, fn, mode="wb", autocommit=True, seek=0, **kwargs):
        self.fs = fs
        self.path = path
        self.fn = fn
        self.mode = mode
        self.autocommit = autocommit
        self.kwargs = kwargs
        self.size = None
        self.closed = False
        self.fh = open(fn, mode)
        if seek:
            self.fh.seek(seek)

    def __reduce__(self):
        # always open in r+b to allow continuing writing at a location
        return (
            LocalTempFile,
            (self.fs, self.path, self.fn, "r+b", self.autocommit, self.tell()),
        )

    def __enter__(self):
        # context manager hands back the raw handle, not this wrapper
        return self.fh

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        # a second close() is a no-op
        if not self.closed:
            self.fh.close()
            self.closed = True
            if self.autocommit:
                self.commit()

    def discard(self):
        """Abandon the pending upload and delete the local scratch file."""
        self.fh.close()
        os.remove(self.fn)

    def commit(self):
        """Upload the local contents to the remote path."""
        self.fs.put(self.fn, self.path, **self.kwargs)
        # we do not delete the local copy, it's still in the cache.

    @property
    def name(self):
        return self.fn

    def __repr__(self) -> str:
        return f"LocalTempFile: {self.path}"

    def __getattr__(self, item):
        return getattr(self.fh, item)
|
pythonProject/.venv/Lib/site-packages/fsspec/implementations/dask.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dask
|
| 2 |
+
from distributed.client import Client, _get_global_client
|
| 3 |
+
from distributed.worker import Worker
|
| 4 |
+
|
| 5 |
+
from fsspec import filesystem
|
| 6 |
+
from fsspec.spec import AbstractBufferedFile, AbstractFileSystem
|
| 7 |
+
from fsspec.utils import infer_storage_options
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _get_client(client):
    """Resolve *client* to a ``distributed.Client`` instance.

    ``None`` means "use the current global client"; an existing ``Client``
    is passed through unchanged; anything else (e.g. a scheduler address
    string) is used to construct a new ``Client``.
    """
    if isinstance(client, Client):
        return client
    if client is None:
        return _get_global_client()
    # e.g., connection string
    return Client(client)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _in_worker():
    """True when this process is running inside a dask worker."""
    return len(Worker._instances) > 0
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class DaskWorkerFileSystem(AbstractFileSystem):
    """View files accessible to a worker as any other remote file-system

    When instances are run on the worker, uses the real filesystem. When
    run on the client, they call the worker to provide information or data.

    **Warning** this implementation is experimental, and read-only for now.
    """

    def __init__(
        self, target_protocol=None, target_options=None, fs=None, client=None, **kwargs
    ):
        # exactly one of `fs` (a ready-made instance) or `target_protocol`
        # (a protocol name to instantiate on the worker) must be given
        super().__init__(**kwargs)
        if not (fs is None) ^ (target_protocol is None):
            raise ValueError(
                "Please provide one of filesystem instance (fs) or"
                " target_protocol, not both"
            )
        self.target_protocol = target_protocol
        self.target_options = target_options
        self.worker = None
        self.client = client
        self.fs = fs
        self._determine_worker()

    @staticmethod
    def _get_kwargs_from_urls(path):
        # URLs carrying host:port give the scheduler address for the client
        so = infer_storage_options(path)
        if "host" in so and "port" in so:
            return {"client": f"{so['host']}:{so['port']}"}
        else:
            return {}

    def _determine_worker(self):
        # decide once, at construction, which side we are on; `self.rfs` is
        # a dask.delayed proxy used to execute methods on a worker when we
        # are on the client side
        if _in_worker():
            self.worker = True
            if self.fs is None:
                self.fs = filesystem(
                    self.target_protocol, **(self.target_options or {})
                )
        else:
            self.worker = False
            self.client = _get_client(self.client)
            self.rfs = dask.delayed(self)

    def mkdir(self, *args, **kwargs):
        # every public method follows the same pattern: run locally when on
        # the worker, otherwise ship the call to a worker and block on it
        if self.worker:
            self.fs.mkdir(*args, **kwargs)
        else:
            self.rfs.mkdir(*args, **kwargs).compute()

    def rm(self, *args, **kwargs):
        if self.worker:
            self.fs.rm(*args, **kwargs)
        else:
            self.rfs.rm(*args, **kwargs).compute()

    def copy(self, *args, **kwargs):
        if self.worker:
            self.fs.copy(*args, **kwargs)
        else:
            self.rfs.copy(*args, **kwargs).compute()

    def mv(self, *args, **kwargs):
        if self.worker:
            self.fs.mv(*args, **kwargs)
        else:
            self.rfs.mv(*args, **kwargs).compute()

    def ls(self, *args, **kwargs):
        if self.worker:
            return self.fs.ls(*args, **kwargs)
        else:
            return self.rfs.ls(*args, **kwargs).compute()

    def _open(
        self,
        path,
        mode="rb",
        block_size=None,
        autocommit=True,
        cache_options=None,
        **kwargs,
    ):
        # on a worker: delegate directly; on the client: return a DaskFile,
        # which fetches byte ranges through `fetch_range` below
        if self.worker:
            return self.fs._open(
                path,
                mode=mode,
                block_size=block_size,
                autocommit=autocommit,
                cache_options=cache_options,
                **kwargs,
            )
        else:
            return DaskFile(
                fs=self,
                path=path,
                mode=mode,
                block_size=block_size,
                autocommit=autocommit,
                cache_options=cache_options,
                **kwargs,
            )

    def fetch_range(self, path, mode, start, end):
        """Read bytes [start, end) of *path*, executing on a worker if
        called from the client side."""
        # NOTE(review): on the worker this re-opens the file on every call —
        # simple, but potentially slow for many small reads
        if self.worker:
            with self._open(path, mode) as f:
                f.seek(start)
                return f.read(end - start)
        else:
            return self.rfs.fetch_range(path, mode, start, end).compute()
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class DaskFile(AbstractBufferedFile):
    """Client-side, read-only file whose byte ranges are fetched via a
    dask worker (see ``DaskWorkerFileSystem.fetch_range``)."""

    def __init__(self, mode="rb", **kwargs):
        # writing would require coordinating upload state across workers,
        # which is not implemented
        if mode != "rb":
            raise ValueError('Remote dask files can only be opened in "rb" mode')
        super().__init__(**kwargs)

    def _upload_chunk(self, final=False):
        # read-only: nothing to upload
        pass

    def _initiate_upload(self):
        """Create remote file/upload"""
        # read-only: nothing to initiate
        pass

    def _fetch_range(self, start, end):
        """Get the specified set of bytes from remote"""
        return self.fs.fetch_range(self.path, self.mode, start, end)
|
pythonProject/.venv/Lib/site-packages/fsspec/implementations/data.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import io
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from urllib.parse import unquote
|
| 5 |
+
|
| 6 |
+
from fsspec import AbstractFileSystem
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DataFileSystem(AbstractFileSystem):
    """A handy decoder for data-URLs

    Example
    -------
    >>> with fsspec.open("data:,Hello%2C%20World%21") as f:
    ...     print(f.read())
    b"Hello, World!"

    See https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/Data_URLs
    """

    protocol = "data"

    def __init__(self, **kwargs):
        """No parameters for this filesystem"""
        super().__init__(**kwargs)

    def cat_file(self, path, start=None, end=None, **kwargs):
        # everything after the first comma is the payload; the header
        # decides whether it is base64 or percent-encoded text
        header, payload = path.split(",", 1)
        if header.endswith("base64"):
            raw = base64.b64decode(payload)
        else:
            raw = unquote(payload).encode()
        return raw[start:end]

    def info(self, path, **kwargs):
        header, name = path.split(",", 1)
        # header looks like "data:<mime>[;base64]"
        mime = header.split(":", 1)[1].split(";", 1)[0]
        return {
            "name": name,
            "size": len(self.cat_file(path)),
            "type": "file",
            "mimetype": mime,
        }

    def _open(
        self,
        path,
        mode="rb",
        block_size=None,
        autocommit=True,
        cache_options=None,
        **kwargs,
    ):
        if "r" not in mode:
            raise ValueError("Read only filesystem")
        return io.BytesIO(self.cat_file(path))

    @staticmethod
    def encode(data: bytes, mime: Optional[str] = None):
        """Format the given data into data-URL syntax

        This version always base64 encodes, even when the data is ascii/url-safe.
        """
        payload = base64.b64encode(data).decode()
        return f"data:{mime or ''};base64,{payload}"
|
pythonProject/.venv/Lib/site-packages/fsspec/implementations/dbfs.py
ADDED
|
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import base64
|
| 4 |
+
import urllib
|
| 5 |
+
|
| 6 |
+
import requests
|
| 7 |
+
from requests.adapters import HTTPAdapter, Retry
|
| 8 |
+
from typing_extensions import override
|
| 9 |
+
|
| 10 |
+
from fsspec import AbstractFileSystem
|
| 11 |
+
from fsspec.spec import AbstractBufferedFile
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DatabricksException(Exception):
    """Raised when the Databricks REST API reports an error.

    Carries the API's machine-readable ``error_code`` alongside the
    human-readable ``message``.
    """

    def __init__(self, error_code, message, details=None):
        """Create a new DatabricksException"""
        super().__init__(message)
        self.error_code = error_code
        self.message = message
        self.details = details
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DatabricksFileSystem(AbstractFileSystem):
|
| 29 |
+
"""
|
| 30 |
+
Get access to the Databricks filesystem implementation over HTTP.
|
| 31 |
+
Can be used inside and outside of a databricks cluster.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
    def __init__(self, instance, token, **kwargs):
        """
        Create a new DatabricksFileSystem.

        Parameters
        ----------
        instance: str
            The instance URL of the databricks cluster.
            For example for an Azure databricks cluster, this
            has the form adb-<some-number>.<two digits>.azuredatabricks.net.
        token: str
            Your personal token. Find out more
            here: https://docs.databricks.com/dev-tools/api/latest/authentication.html
        """
        self.instance = instance
        self.token = token
        # one session for connection reuse; retry transient HTTP failures
        # (timeouts, throttling, gateway errors) with exponential backoff
        self.session = requests.Session()
        self.retries = Retry(
            total=10,
            backoff_factor=0.05,
            status_forcelist=[408, 429, 500, 502, 503, 504],
        )

        self.session.mount("https://", HTTPAdapter(max_retries=self.retries))
        self.session.headers.update({"Authorization": f"Bearer {self.token}"})

        super().__init__(**kwargs)
|
| 61 |
+
|
| 62 |
+
    @override
    def _ls_from_cache(self, path) -> list[dict[str, str | int]] | None:
        """Check cache for listing

        Returns listing, if found (may be empty list for a directory that
        exists but contains nothing), None if not in cache.
        """
        # always drop any cached listing of `path` itself, so directory
        # contents are re-fetched; only the parent's entry is trusted here
        self.dircache.pop(path.rstrip("/"), None)

        parent = self._parent(path)
        if parent in self.dircache:
            for entry in self.dircache[parent]:
                if entry["name"] == path.rstrip("/"):
                    if entry["type"] != "directory":
                        return [entry]
                    # it's a directory: contents unknown → signal cache miss
                    return []
            # parent listing is cached and does not contain `path`
            raise FileNotFoundError(path)
|
| 79 |
+
|
| 80 |
+
    def ls(self, path, detail=True, **kwargs):
        """
        List the contents of the given path.

        Parameters
        ----------
        path: str
            Absolute path
        detail: bool
            Return not only the list of filenames,
            but also additional information on file sizes
            and types.
        """
        try:
            out = self._ls_from_cache(path)
        except FileNotFoundError:
            # This happens if the `path`'s parent was cached, but `path` is not
            # there. This suggests that `path` is new since the parent was
            # cached. Attempt to invalidate parent's cache before continuing.
            self.dircache.pop(self._parent(path), None)
            out = None

        if not out:
            try:
                r = self._send_to_api(
                    method="get", endpoint="list", json={"path": path}
                )
            except DatabricksException as e:
                if e.error_code == "RESOURCE_DOES_NOT_EXIST":
                    raise FileNotFoundError(e.message) from e

                raise
            files = r.get("files", [])
            # translate the DBFS API schema into fsspec's listing schema
            out = [
                {
                    "name": o["path"],
                    "type": "directory" if o["is_dir"] else "file",
                    "size": o["file_size"],
                }
                for o in files
            ]
            self.dircache[path] = out

        if detail:
            return out
        return [o["name"] for o in out]
|
| 126 |
+
|
| 127 |
+
def makedirs(self, path, exist_ok=True):
|
| 128 |
+
"""
|
| 129 |
+
Create a given absolute path and all of its parents.
|
| 130 |
+
|
| 131 |
+
Parameters
|
| 132 |
+
----------
|
| 133 |
+
path: str
|
| 134 |
+
Absolute path to create
|
| 135 |
+
exist_ok: bool
|
| 136 |
+
If false, checks if the folder
|
| 137 |
+
exists before creating it (and raises an
|
| 138 |
+
Exception if this is the case)
|
| 139 |
+
"""
|
| 140 |
+
if not exist_ok:
|
| 141 |
+
try:
|
| 142 |
+
# If the following succeeds, the path is already present
|
| 143 |
+
self._send_to_api(
|
| 144 |
+
method="get", endpoint="get-status", json={"path": path}
|
| 145 |
+
)
|
| 146 |
+
raise FileExistsError(f"Path {path} already exists")
|
| 147 |
+
except DatabricksException as e:
|
| 148 |
+
if e.error_code == "RESOURCE_DOES_NOT_EXIST":
|
| 149 |
+
pass
|
| 150 |
+
|
| 151 |
+
try:
|
| 152 |
+
self._send_to_api(method="post", endpoint="mkdirs", json={"path": path})
|
| 153 |
+
except DatabricksException as e:
|
| 154 |
+
if e.error_code == "RESOURCE_ALREADY_EXISTS":
|
| 155 |
+
raise FileExistsError(e.message) from e
|
| 156 |
+
|
| 157 |
+
raise
|
| 158 |
+
self.invalidate_cache(self._parent(path))
|
| 159 |
+
|
| 160 |
+
def mkdir(self, path, create_parents=True, **kwargs):
|
| 161 |
+
"""
|
| 162 |
+
Create a given absolute path and all of its parents.
|
| 163 |
+
|
| 164 |
+
Parameters
|
| 165 |
+
----------
|
| 166 |
+
path: str
|
| 167 |
+
Absolute path to create
|
| 168 |
+
create_parents: bool
|
| 169 |
+
Whether to create all parents or not.
|
| 170 |
+
"False" is not implemented so far.
|
| 171 |
+
"""
|
| 172 |
+
if not create_parents:
|
| 173 |
+
raise NotImplementedError
|
| 174 |
+
|
| 175 |
+
self.mkdirs(path, **kwargs)
|
| 176 |
+
|
| 177 |
+
def rm(self, path, recursive=False, **kwargs):
|
| 178 |
+
"""
|
| 179 |
+
Remove the file or folder at the given absolute path.
|
| 180 |
+
|
| 181 |
+
Parameters
|
| 182 |
+
----------
|
| 183 |
+
path: str
|
| 184 |
+
Absolute path what to remove
|
| 185 |
+
recursive: bool
|
| 186 |
+
Recursively delete all files in a folder.
|
| 187 |
+
"""
|
| 188 |
+
try:
|
| 189 |
+
self._send_to_api(
|
| 190 |
+
method="post",
|
| 191 |
+
endpoint="delete",
|
| 192 |
+
json={"path": path, "recursive": recursive},
|
| 193 |
+
)
|
| 194 |
+
except DatabricksException as e:
|
| 195 |
+
# This is not really an exception, it just means
|
| 196 |
+
# not everything was deleted so far
|
| 197 |
+
if e.error_code == "PARTIAL_DELETE":
|
| 198 |
+
self.rm(path=path, recursive=recursive)
|
| 199 |
+
elif e.error_code == "IO_ERROR":
|
| 200 |
+
# Using the same exception as the os module would use here
|
| 201 |
+
raise OSError(e.message) from e
|
| 202 |
+
|
| 203 |
+
raise
|
| 204 |
+
self.invalidate_cache(self._parent(path))
|
| 205 |
+
|
| 206 |
+
def mv(
|
| 207 |
+
self, source_path, destination_path, recursive=False, maxdepth=None, **kwargs
|
| 208 |
+
):
|
| 209 |
+
"""
|
| 210 |
+
Move a source to a destination path.
|
| 211 |
+
|
| 212 |
+
A note from the original [databricks API manual]
|
| 213 |
+
(https://docs.databricks.com/dev-tools/api/latest/dbfs.html#move).
|
| 214 |
+
|
| 215 |
+
When moving a large number of files the API call will time out after
|
| 216 |
+
approximately 60s, potentially resulting in partially moved data.
|
| 217 |
+
Therefore, for operations that move more than 10k files, we strongly
|
| 218 |
+
discourage using the DBFS REST API.
|
| 219 |
+
|
| 220 |
+
Parameters
|
| 221 |
+
----------
|
| 222 |
+
source_path: str
|
| 223 |
+
From where to move (absolute path)
|
| 224 |
+
destination_path: str
|
| 225 |
+
To where to move (absolute path)
|
| 226 |
+
recursive: bool
|
| 227 |
+
Not implemented to far.
|
| 228 |
+
maxdepth:
|
| 229 |
+
Not implemented to far.
|
| 230 |
+
"""
|
| 231 |
+
if recursive:
|
| 232 |
+
raise NotImplementedError
|
| 233 |
+
if maxdepth:
|
| 234 |
+
raise NotImplementedError
|
| 235 |
+
|
| 236 |
+
try:
|
| 237 |
+
self._send_to_api(
|
| 238 |
+
method="post",
|
| 239 |
+
endpoint="move",
|
| 240 |
+
json={"source_path": source_path, "destination_path": destination_path},
|
| 241 |
+
)
|
| 242 |
+
except DatabricksException as e:
|
| 243 |
+
if e.error_code == "RESOURCE_DOES_NOT_EXIST":
|
| 244 |
+
raise FileNotFoundError(e.message) from e
|
| 245 |
+
elif e.error_code == "RESOURCE_ALREADY_EXISTS":
|
| 246 |
+
raise FileExistsError(e.message) from e
|
| 247 |
+
|
| 248 |
+
raise
|
| 249 |
+
self.invalidate_cache(self._parent(source_path))
|
| 250 |
+
self.invalidate_cache(self._parent(destination_path))
|
| 251 |
+
|
| 252 |
+
def _open(self, path, mode="rb", block_size="default", **kwargs):
|
| 253 |
+
"""
|
| 254 |
+
Overwrite the base class method to make sure to create a DBFile.
|
| 255 |
+
All arguments are copied from the base method.
|
| 256 |
+
|
| 257 |
+
Only the default blocksize is allowed.
|
| 258 |
+
"""
|
| 259 |
+
return DatabricksFile(self, path, mode=mode, block_size=block_size, **kwargs)
|
| 260 |
+
|
| 261 |
+
def _send_to_api(self, method, endpoint, json):
|
| 262 |
+
"""
|
| 263 |
+
Send the given json to the DBFS API
|
| 264 |
+
using a get or post request (specified by the argument `method`).
|
| 265 |
+
|
| 266 |
+
Parameters
|
| 267 |
+
----------
|
| 268 |
+
method: str
|
| 269 |
+
Which http method to use for communication; "get" or "post".
|
| 270 |
+
endpoint: str
|
| 271 |
+
Where to send the request to (last part of the API URL)
|
| 272 |
+
json: dict
|
| 273 |
+
Dictionary of information to send
|
| 274 |
+
"""
|
| 275 |
+
if method == "post":
|
| 276 |
+
session_call = self.session.post
|
| 277 |
+
elif method == "get":
|
| 278 |
+
session_call = self.session.get
|
| 279 |
+
else:
|
| 280 |
+
raise ValueError(f"Do not understand method {method}")
|
| 281 |
+
|
| 282 |
+
url = urllib.parse.urljoin(f"https://{self.instance}/api/2.0/dbfs/", endpoint)
|
| 283 |
+
|
| 284 |
+
r = session_call(url, json=json)
|
| 285 |
+
|
| 286 |
+
# The DBFS API will return a json, also in case of an exception.
|
| 287 |
+
# We want to preserve this information as good as possible.
|
| 288 |
+
try:
|
| 289 |
+
r.raise_for_status()
|
| 290 |
+
except requests.HTTPError as e:
|
| 291 |
+
# try to extract json error message
|
| 292 |
+
# if that fails, fall back to the original exception
|
| 293 |
+
try:
|
| 294 |
+
exception_json = e.response.json()
|
| 295 |
+
except Exception:
|
| 296 |
+
raise e from None
|
| 297 |
+
|
| 298 |
+
raise DatabricksException(**exception_json) from e
|
| 299 |
+
|
| 300 |
+
return r.json()
|
| 301 |
+
|
| 302 |
+
def _create_handle(self, path, overwrite=True):
|
| 303 |
+
"""
|
| 304 |
+
Internal function to create a handle, which can be used to
|
| 305 |
+
write blocks of a file to DBFS.
|
| 306 |
+
A handle has a unique identifier which needs to be passed
|
| 307 |
+
whenever written during this transaction.
|
| 308 |
+
The handle is active for 10 minutes - after that a new
|
| 309 |
+
write transaction needs to be created.
|
| 310 |
+
Make sure to close the handle after you are finished.
|
| 311 |
+
|
| 312 |
+
Parameters
|
| 313 |
+
----------
|
| 314 |
+
path: str
|
| 315 |
+
Absolute path for this file.
|
| 316 |
+
overwrite: bool
|
| 317 |
+
If a file already exist at this location, either overwrite
|
| 318 |
+
it or raise an exception.
|
| 319 |
+
"""
|
| 320 |
+
try:
|
| 321 |
+
r = self._send_to_api(
|
| 322 |
+
method="post",
|
| 323 |
+
endpoint="create",
|
| 324 |
+
json={"path": path, "overwrite": overwrite},
|
| 325 |
+
)
|
| 326 |
+
return r["handle"]
|
| 327 |
+
except DatabricksException as e:
|
| 328 |
+
if e.error_code == "RESOURCE_ALREADY_EXISTS":
|
| 329 |
+
raise FileExistsError(e.message) from e
|
| 330 |
+
|
| 331 |
+
raise
|
| 332 |
+
|
| 333 |
+
def _close_handle(self, handle):
|
| 334 |
+
"""
|
| 335 |
+
Close a handle, which was opened by :func:`_create_handle`.
|
| 336 |
+
|
| 337 |
+
Parameters
|
| 338 |
+
----------
|
| 339 |
+
handle: str
|
| 340 |
+
Which handle to close.
|
| 341 |
+
"""
|
| 342 |
+
try:
|
| 343 |
+
self._send_to_api(method="post", endpoint="close", json={"handle": handle})
|
| 344 |
+
except DatabricksException as e:
|
| 345 |
+
if e.error_code == "RESOURCE_DOES_NOT_EXIST":
|
| 346 |
+
raise FileNotFoundError(e.message) from e
|
| 347 |
+
|
| 348 |
+
raise
|
| 349 |
+
|
| 350 |
+
def _add_data(self, handle, data):
|
| 351 |
+
"""
|
| 352 |
+
Upload data to an already opened file handle
|
| 353 |
+
(opened by :func:`_create_handle`).
|
| 354 |
+
The maximal allowed data size is 1MB after
|
| 355 |
+
conversion to base64.
|
| 356 |
+
Remember to close the handle when you are finished.
|
| 357 |
+
|
| 358 |
+
Parameters
|
| 359 |
+
----------
|
| 360 |
+
handle: str
|
| 361 |
+
Which handle to upload data to.
|
| 362 |
+
data: bytes
|
| 363 |
+
Block of data to add to the handle.
|
| 364 |
+
"""
|
| 365 |
+
data = base64.b64encode(data).decode()
|
| 366 |
+
try:
|
| 367 |
+
self._send_to_api(
|
| 368 |
+
method="post",
|
| 369 |
+
endpoint="add-block",
|
| 370 |
+
json={"handle": handle, "data": data},
|
| 371 |
+
)
|
| 372 |
+
except DatabricksException as e:
|
| 373 |
+
if e.error_code == "RESOURCE_DOES_NOT_EXIST":
|
| 374 |
+
raise FileNotFoundError(e.message) from e
|
| 375 |
+
elif e.error_code == "MAX_BLOCK_SIZE_EXCEEDED":
|
| 376 |
+
raise ValueError(e.message) from e
|
| 377 |
+
|
| 378 |
+
raise
|
| 379 |
+
|
| 380 |
+
def _get_data(self, path, start, end):
|
| 381 |
+
"""
|
| 382 |
+
Download data in bytes from a given absolute path in a block
|
| 383 |
+
from [start, start+length].
|
| 384 |
+
The maximum number of allowed bytes to read is 1MB.
|
| 385 |
+
|
| 386 |
+
Parameters
|
| 387 |
+
----------
|
| 388 |
+
path: str
|
| 389 |
+
Absolute path to download data from
|
| 390 |
+
start: int
|
| 391 |
+
Start position of the block
|
| 392 |
+
end: int
|
| 393 |
+
End position of the block
|
| 394 |
+
"""
|
| 395 |
+
try:
|
| 396 |
+
r = self._send_to_api(
|
| 397 |
+
method="get",
|
| 398 |
+
endpoint="read",
|
| 399 |
+
json={"path": path, "offset": start, "length": end - start},
|
| 400 |
+
)
|
| 401 |
+
return base64.b64decode(r["data"])
|
| 402 |
+
except DatabricksException as e:
|
| 403 |
+
if e.error_code == "RESOURCE_DOES_NOT_EXIST":
|
| 404 |
+
raise FileNotFoundError(e.message) from e
|
| 405 |
+
elif e.error_code in ["INVALID_PARAMETER_VALUE", "MAX_READ_SIZE_EXCEEDED"]:
|
| 406 |
+
raise ValueError(e.message) from e
|
| 407 |
+
|
| 408 |
+
raise
|
| 409 |
+
|
| 410 |
+
def invalidate_cache(self, path=None):
|
| 411 |
+
if path is None:
|
| 412 |
+
self.dircache.clear()
|
| 413 |
+
else:
|
| 414 |
+
self.dircache.pop(path, None)
|
| 415 |
+
super().invalidate_cache(path)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
class DatabricksFile(AbstractBufferedFile):
    """
    Helper class for files referenced in the DatabricksFileSystem.

    Reads and writes are chunked into blocks of exactly
    ``DEFAULT_BLOCK_SIZE`` bytes, the only size the DBFS block API accepts.
    """

    DEFAULT_BLOCK_SIZE = 1 * 2**20  # only allowed block size

    def __init__(
        self,
        fs,
        path,
        mode="rb",
        block_size="default",
        autocommit=True,
        cache_type="readahead",
        cache_options=None,
        **kwargs,
    ):
        """
        Create a new instance of the DatabricksFile.

        The blocksize needs to be the default one.
        """
        if block_size is None or block_size == "default":
            block_size = self.DEFAULT_BLOCK_SIZE

        assert block_size == self.DEFAULT_BLOCK_SIZE, (
            f"Only the default block size is allowed, not {block_size}"
        )

        super().__init__(
            fs,
            path,
            mode=mode,
            block_size=block_size,
            autocommit=autocommit,
            cache_type=cache_type,
            cache_options=cache_options or {},
            **kwargs,
        )

    def _initiate_upload(self):
        """Internal function to start a file upload"""
        self.handle = self.fs._create_handle(self.path)

    def _upload_chunk(self, final=False):
        """Internal function to add a chunk of data to a started upload"""
        self.buffer.seek(0)
        data = self.buffer.getvalue()

        # Slice lazily instead of materializing every chunk into a list
        # first; this avoids holding a second full copy of the buffer.
        for start, end in self._to_sized_blocks(len(data)):
            self.fs._add_data(handle=self.handle, data=data[start:end])

        if final:
            self.fs._close_handle(handle=self.handle)
            return True

    def _fetch_range(self, start, end):
        """Internal function to download a block of data"""
        # Join once at the end instead of repeated bytes concatenation,
        # which copies the accumulated buffer on every iteration.
        chunks = [
            self.fs._get_data(path=self.path, start=chunk_start, end=chunk_end)
            for chunk_start, chunk_end in self._to_sized_blocks(end - start, start)
        ]
        return b"".join(chunks)

    def _to_sized_blocks(self, length, start=0):
        """Yield (start, end) pairs covering ``length`` bytes from ``start``,
        each at most ``self.blocksize`` long."""
        end = start + length
        for data_chunk in range(start, end, self.blocksize):
            data_start = data_chunk
            data_end = min(end, data_chunk + self.blocksize)
            yield data_start, data_end
|
pythonProject/.venv/Lib/site-packages/fsspec/implementations/dirfs.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .. import filesystem
|
| 2 |
+
from ..asyn import AsyncFileSystem
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class DirFileSystem(AsyncFileSystem):
    """Directory prefix filesystem

    The DirFileSystem is a filesystem-wrapper. It assumes every path it is dealing with
    is relative to the `path`. After performing the necessary paths operation it
    delegates everything to the wrapped filesystem.
    """

    protocol = "dir"

    def __init__(
        self,
        path=None,
        fs=None,
        fo=None,
        target_protocol=None,
        target_options=None,
        **storage_options,
    ):
        """
        Parameters
        ----------
        path: str
            Path to the directory.
        fs: AbstractFileSystem
            An instantiated filesystem to wrap.
        target_protocol, target_options:
            if fs is none, construct it from these
        fo: str
            Alternate for path; do not provide both
        """
        super().__init__(**storage_options)
        if fs is None:
            fs = filesystem(protocol=target_protocol, **(target_options or {}))
        path = path or fo

        # The wrapper and the wrapped fs must agree on sync/async mode.
        if self.asynchronous and not fs.async_impl:
            raise ValueError("can't use asynchronous with non-async fs")

        if fs.async_impl and self.asynchronous != fs.asynchronous:
            raise ValueError("both dirfs and fs should be in the same sync/async mode")

        self.path = fs._strip_protocol(path)
        self.fs = fs

    def _join(self, path):
        """Prefix ``path`` (str, dict of paths, or iterable of paths) with self.path."""
        if isinstance(path, str):
            if not self.path:
                return path
            if not path:
                return self.path
            return self.fs.sep.join((self.path, self._strip_protocol(path)))
        if isinstance(path, dict):
            return {self._join(_path): value for _path, value in path.items()}
        return [self._join(_path) for _path in path]

    def _relpath(self, path):
        """Strip the self.path prefix from ``path`` (str or iterable of paths)."""
        if isinstance(path, str):
            if not self.path:
                return path
            # We need to account for S3FileSystem returning paths that do not
            # start with a '/'
            if path == self.path or (
                self.path.startswith(self.fs.sep) and path == self.path[1:]
            ):
                return ""
            prefix = self.path + self.fs.sep
            if self.path.startswith(self.fs.sep) and not path.startswith(self.fs.sep):
                prefix = prefix[1:]
            assert path.startswith(prefix)
            return path[len(prefix) :]
        return [self._relpath(_path) for _path in path]

    # Wrappers below
    #
    # Each pair of ``_name`` (async) and ``name`` (sync) methods joins the
    # incoming path(s) onto the prefix, delegates to the wrapped filesystem,
    # and, where the result contains paths, maps them back through _relpath.

    @property
    def sep(self):
        return self.fs.sep

    async def set_session(self, *args, **kwargs):
        return await self.fs.set_session(*args, **kwargs)

    async def _rm_file(self, path, **kwargs):
        return await self.fs._rm_file(self._join(path), **kwargs)

    def rm_file(self, path, **kwargs):
        return self.fs.rm_file(self._join(path), **kwargs)

    async def _rm(self, path, *args, **kwargs):
        return await self.fs._rm(self._join(path), *args, **kwargs)

    def rm(self, path, *args, **kwargs):
        return self.fs.rm(self._join(path), *args, **kwargs)

    async def _cp_file(self, path1, path2, **kwargs):
        return await self.fs._cp_file(self._join(path1), self._join(path2), **kwargs)

    def cp_file(self, path1, path2, **kwargs):
        return self.fs.cp_file(self._join(path1), self._join(path2), **kwargs)

    async def _copy(
        self,
        path1,
        path2,
        *args,
        **kwargs,
    ):
        return await self.fs._copy(
            self._join(path1),
            self._join(path2),
            *args,
            **kwargs,
        )

    def copy(self, path1, path2, *args, **kwargs):
        return self.fs.copy(
            self._join(path1),
            self._join(path2),
            *args,
            **kwargs,
        )

    async def _pipe(self, path, *args, **kwargs):
        return await self.fs._pipe(self._join(path), *args, **kwargs)

    def pipe(self, path, *args, **kwargs):
        return self.fs.pipe(self._join(path), *args, **kwargs)

    async def _pipe_file(self, path, *args, **kwargs):
        return await self.fs._pipe_file(self._join(path), *args, **kwargs)

    def pipe_file(self, path, *args, **kwargs):
        return self.fs.pipe_file(self._join(path), *args, **kwargs)

    async def _cat_file(self, path, *args, **kwargs):
        return await self.fs._cat_file(self._join(path), *args, **kwargs)

    def cat_file(self, path, *args, **kwargs):
        return self.fs.cat_file(self._join(path), *args, **kwargs)

    async def _cat(self, path, *args, **kwargs):
        ret = await self.fs._cat(
            self._join(path),
            *args,
            **kwargs,
        )

        # cat may return a mapping of path -> bytes; rewrite the keys.
        if isinstance(ret, dict):
            return {self._relpath(key): value for key, value in ret.items()}

        return ret

    def cat(self, path, *args, **kwargs):
        ret = self.fs.cat(
            self._join(path),
            *args,
            **kwargs,
        )

        # cat may return a mapping of path -> bytes; rewrite the keys.
        if isinstance(ret, dict):
            return {self._relpath(key): value for key, value in ret.items()}

        return ret

    async def _put_file(self, lpath, rpath, **kwargs):
        return await self.fs._put_file(lpath, self._join(rpath), **kwargs)

    def put_file(self, lpath, rpath, **kwargs):
        return self.fs.put_file(lpath, self._join(rpath), **kwargs)

    async def _put(
        self,
        lpath,
        rpath,
        *args,
        **kwargs,
    ):
        return await self.fs._put(
            lpath,
            self._join(rpath),
            *args,
            **kwargs,
        )

    def put(self, lpath, rpath, *args, **kwargs):
        return self.fs.put(
            lpath,
            self._join(rpath),
            *args,
            **kwargs,
        )

    async def _get_file(self, rpath, lpath, **kwargs):
        return await self.fs._get_file(self._join(rpath), lpath, **kwargs)

    def get_file(self, rpath, lpath, **kwargs):
        return self.fs.get_file(self._join(rpath), lpath, **kwargs)

    async def _get(self, rpath, *args, **kwargs):
        return await self.fs._get(self._join(rpath), *args, **kwargs)

    def get(self, rpath, *args, **kwargs):
        return self.fs.get(self._join(rpath), *args, **kwargs)

    async def _isfile(self, path):
        return await self.fs._isfile(self._join(path))

    def isfile(self, path):
        return self.fs.isfile(self._join(path))

    async def _isdir(self, path):
        return await self.fs._isdir(self._join(path))

    def isdir(self, path):
        return self.fs.isdir(self._join(path))

    async def _size(self, path):
        return await self.fs._size(self._join(path))

    def size(self, path):
        return self.fs.size(self._join(path))

    async def _exists(self, path):
        return await self.fs._exists(self._join(path))

    def exists(self, path):
        return self.fs.exists(self._join(path))

    async def _info(self, path, **kwargs):
        # Copy before mutating so the wrapped fs's cached info stays intact.
        info = await self.fs._info(self._join(path), **kwargs)
        info = info.copy()
        info["name"] = self._relpath(info["name"])
        return info

    def info(self, path, **kwargs):
        # Copy before mutating so the wrapped fs's cached info stays intact.
        info = self.fs.info(self._join(path), **kwargs)
        info = info.copy()
        info["name"] = self._relpath(info["name"])
        return info

    async def _ls(self, path, detail=True, **kwargs):
        ret = (await self.fs._ls(self._join(path), detail=detail, **kwargs)).copy()
        if detail:
            # Copy each entry before rewriting its name.
            out = []
            for entry in ret:
                entry = entry.copy()
                entry["name"] = self._relpath(entry["name"])
                out.append(entry)
            return out

        return self._relpath(ret)

    def ls(self, path, detail=True, **kwargs):
        ret = self.fs.ls(self._join(path), detail=detail, **kwargs).copy()
        if detail:
            # Copy each entry before rewriting its name.
            out = []
            for entry in ret:
                entry = entry.copy()
                entry["name"] = self._relpath(entry["name"])
                out.append(entry)
            return out

        return self._relpath(ret)

    async def _walk(self, path, *args, **kwargs):
        async for root, dirs, files in self.fs._walk(self._join(path), *args, **kwargs):
            yield self._relpath(root), dirs, files

    def walk(self, path, *args, **kwargs):
        for root, dirs, files in self.fs.walk(self._join(path), *args, **kwargs):
            yield self._relpath(root), dirs, files

    async def _glob(self, path, **kwargs):
        detail = kwargs.get("detail", False)
        ret = await self.fs._glob(self._join(path), **kwargs)
        if detail:
            return {self._relpath(path): info for path, info in ret.items()}
        return self._relpath(ret)

    def glob(self, path, **kwargs):
        detail = kwargs.get("detail", False)
        ret = self.fs.glob(self._join(path), **kwargs)
        if detail:
            return {self._relpath(path): info for path, info in ret.items()}
        return self._relpath(ret)

    async def _du(self, path, *args, **kwargs):
        total = kwargs.get("total", True)
        ret = await self.fs._du(self._join(path), *args, **kwargs)
        if total:
            return ret

        return {self._relpath(path): size for path, size in ret.items()}

    def du(self, path, *args, **kwargs):
        total = kwargs.get("total", True)
        ret = self.fs.du(self._join(path), *args, **kwargs)
        if total:
            return ret

        return {self._relpath(path): size for path, size in ret.items()}

    async def _find(self, path, *args, **kwargs):
        detail = kwargs.get("detail", False)
        ret = await self.fs._find(self._join(path), *args, **kwargs)
        if detail:
            return {self._relpath(path): info for path, info in ret.items()}
        return self._relpath(ret)

    def find(self, path, *args, **kwargs):
        detail = kwargs.get("detail", False)
        ret = self.fs.find(self._join(path), *args, **kwargs)
        if detail:
            return {self._relpath(path): info for path, info in ret.items()}
        return self._relpath(ret)

    async def _expand_path(self, path, *args, **kwargs):
        return self._relpath(
            await self.fs._expand_path(self._join(path), *args, **kwargs)
        )

    def expand_path(self, path, *args, **kwargs):
        return self._relpath(self.fs.expand_path(self._join(path), *args, **kwargs))

    async def _mkdir(self, path, *args, **kwargs):
        return await self.fs._mkdir(self._join(path), *args, **kwargs)

    def mkdir(self, path, *args, **kwargs):
        return self.fs.mkdir(self._join(path), *args, **kwargs)

    async def _makedirs(self, path, *args, **kwargs):
        return await self.fs._makedirs(self._join(path), *args, **kwargs)

    def makedirs(self, path, *args, **kwargs):
        return self.fs.makedirs(self._join(path), *args, **kwargs)

    def rmdir(self, path):
        return self.fs.rmdir(self._join(path))

    def mv(self, path1, path2, **kwargs):
        return self.fs.mv(
            self._join(path1),
            self._join(path2),
            **kwargs,
        )

    def touch(self, path, **kwargs):
        return self.fs.touch(self._join(path), **kwargs)

    def created(self, path):
        return self.fs.created(self._join(path))

    def modified(self, path):
        return self.fs.modified(self._join(path))

    def sign(self, path, *args, **kwargs):
        return self.fs.sign(self._join(path), *args, **kwargs)

    def __repr__(self):
        return f"{self.__class__.__qualname__}(path='{self.path}', fs={self.fs})"

    def open(
        self,
        path,
        *args,
        **kwargs,
    ):
        return self.fs.open(
            self._join(path),
            *args,
            **kwargs,
        )

    async def open_async(
        self,
        path,
        *args,
        **kwargs,
    ):
        return await self.fs.open_async(
            self._join(path),
            *args,
            **kwargs,
        )
|
pythonProject/.venv/Lib/site-packages/fsspec/utils.py
ADDED
|
@@ -0,0 +1,737 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import contextlib
|
| 4 |
+
import logging
|
| 5 |
+
import math
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
import sys
|
| 9 |
+
import tempfile
|
| 10 |
+
from collections.abc import Iterable, Iterator, Sequence
|
| 11 |
+
from functools import partial
|
| 12 |
+
from hashlib import md5
|
| 13 |
+
from importlib.metadata import version
|
| 14 |
+
from typing import (
|
| 15 |
+
IO,
|
| 16 |
+
TYPE_CHECKING,
|
| 17 |
+
Any,
|
| 18 |
+
Callable,
|
| 19 |
+
TypeVar,
|
| 20 |
+
)
|
| 21 |
+
from urllib.parse import urlsplit
|
| 22 |
+
|
| 23 |
+
if TYPE_CHECKING:
|
| 24 |
+
import pathlib
|
| 25 |
+
|
| 26 |
+
from typing_extensions import TypeGuard
|
| 27 |
+
|
| 28 |
+
from fsspec.spec import AbstractFileSystem
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Default I/O block size: 5 MiB.
DEFAULT_BLOCK_SIZE = 5 * 2**20

# Generic type variable used by helpers in this module
# (e.g. make_instance, mirror_from, nullcontext).
T = TypeVar("T")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def infer_storage_options(
    urlpath: str, inherit_storage_options: dict[str, Any] | None = None
) -> dict[str, Any]:
    """Infer storage options from URL path and merge it with existing storage
    options.

    Parameters
    ----------
    urlpath: str or unicode
        Either local absolute file path or URL (hdfs://namenode:8020/file.csv)
    inherit_storage_options: dict (optional)
        Its contents will get merged with the inferred information from the
        given path

    Returns
    -------
    Storage options dict.

    Raises
    ------
    KeyError
        If an inferred option collides with a different value in
        ``inherit_storage_options`` (via ``update_storage_options``).

    Examples
    --------
    >>> infer_storage_options('/mnt/datasets/test.csv')  # doctest: +SKIP
    {"protocol": "file", "path", "/mnt/datasets/test.csv"}
    >>> infer_storage_options(
    ...     'hdfs://username:pwd@node:123/mnt/datasets/test.csv?q=1',
    ...     inherit_storage_options={'extra': 'value'},
    ... )  # doctest: +SKIP
    {"protocol": "hdfs", "username": "username", "password": "pwd",
    "host": "node", "port": 123, "path": "/mnt/datasets/test.csv",
    "url_query": "q=1", "extra": "value"}
    """
    # Handle Windows paths including disk name in this special case:
    # anything like "C:\..." or with no "scheme://" prefix is a local file.
    if (
        re.match(r"^[a-zA-Z]:[\\/]", urlpath)
        or re.match(r"^[a-zA-Z0-9]+://", urlpath) is None
    ):
        return {"protocol": "file", "path": urlpath}

    parsed_path = urlsplit(urlpath)
    protocol = parsed_path.scheme or "file"
    if parsed_path.fragment:
        # Re-attach the fragment: it is part of the path for our purposes.
        path = "#".join([parsed_path.path, parsed_path.fragment])
    else:
        path = parsed_path.path
    if protocol == "file":
        # Special case parsing file protocol URL on Windows according to:
        # https://msdn.microsoft.com/en-us/library/jj710207.aspx
        windows_path = re.match(r"^/([a-zA-Z])[:|]([\\/].*)$", path)
        if windows_path:
            drive, path = windows_path.groups()
            path = f"{drive}:{path}"

    if protocol in ["http", "https"]:
        # for HTTP, we don't want to parse, as requests will anyway
        return {"protocol": protocol, "path": urlpath}

    options: dict[str, Any] = {"protocol": protocol, "path": path}

    if parsed_path.netloc:
        # Parse `host` from netloc manually because `parsed_path.hostname`
        # lowercases the hostname which is not always desirable (e.g. in S3):
        # https://github.com/dask/dask/issues/1417
        options["host"] = parsed_path.netloc.rsplit("@", 1)[-1].rsplit(":", 1)[0]

        if protocol in ("s3", "s3a", "gcs", "gs"):
            # Bucket-style protocols: the "host" is really the bucket name and
            # belongs at the front of the path.
            # (Removed a no-op `options["host"] = options["host"]` else-branch.)
            options["path"] = options["host"] + options["path"]
        if parsed_path.port:
            options["port"] = parsed_path.port
        if parsed_path.username:
            options["username"] = parsed_path.username
        if parsed_path.password:
            options["password"] = parsed_path.password

    if parsed_path.query:
        options["url_query"] = parsed_path.query
    if parsed_path.fragment:
        options["url_fragment"] = parsed_path.fragment

    if inherit_storage_options:
        update_storage_options(options, inherit_storage_options)

    return options
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def update_storage_options(
    options: dict[str, Any], inherited: dict[str, Any] | None = None
) -> None:
    """Merge ``inherited`` into ``options`` in place.

    Raises
    ------
    KeyError
        If a key exists in both dicts with differing values.
    """
    inherited = inherited or {}
    # Reject any key present in both mappings with conflicting values.
    for key in set(options) & set(inherited):
        if options.get(key) != inherited.get(key):
            raise KeyError(
                f"Collision between inferred and specified storage "
                f"option:\n{key}"
            )
    options.update(inherited)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# Compression extensions registered via fsspec.compression.register_compression
compressions: dict[str, str] = {}


def infer_compression(filename: str) -> str | None:
    """Infer compression, if available, from filename.

    Infer a named compression type, if registered and available, from filename
    extension. This includes builtin (gz, bz2, zip) compressions, as well as
    optional compressions. See fsspec.compression.register_compression.

    Returns ``None`` when the extension is not registered.
    """
    # Extension is compared case-insensitively and without the leading dot.
    ext = os.path.splitext(filename)[-1].strip(".").lower()
    return compressions.get(ext)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def build_name_function(max_int: float) -> Callable[[int], str]:
    """Returns a function that receives a single integer
    and returns it as a string padded by enough zero characters
    to align with maximum possible integer

    >>> name_f = build_name_function(57)

    >>> name_f(7)
    '07'
    >>> name_f(31)
    '31'
    >>> build_name_function(1000)(42)
    '0042'
    >>> build_name_function(999)(42)
    '042'
    >>> build_name_function(0)(0)
    '0'
    """
    # Nudge past 0 and exact powers of 10 so log10 yields the right width.
    width = int(math.ceil(math.log10(max_int + 1e-8)))

    def name_function(i: int) -> str:
        return str(i).zfill(width)

    return name_function
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def seek_delimiter(file: IO[bytes], delimiter: bytes, blocksize: int) -> bool:
    r"""Seek current file to file start, file end, or byte after delimiter seq.

    Seeks file to next chunk delimiter, where chunks are defined on file start,
    a delimiting sequence, and file end. Use file.tell() to see location afterwards.
    Note that file start is a valid split, so must be at offset > 0 to seek for
    delimiter.

    Parameters
    ----------
    file: a file
    delimiter: bytes
        a delimiter like ``b'\n'`` or message sentinel, matching file .read() type
    blocksize: int
        Number of bytes to read from the file at once.


    Returns
    -------
    Returns True if a delimiter was found, False if at file start or end.

    """
    if file.tell() == 0:
        # beginning-of-file is itself a valid split point; no seek needed
        return False

    # `tail` carries the last len(delimiter) bytes of the previous chunk so a
    # delimiter straddling a chunk boundary is still detected. Initialized from
    # file.read() results to stay compatible with text-mode file-likes.
    tail: bytes | None = None
    while True:
        chunk = file.read(blocksize)
        if not chunk:
            # end-of-file without delimiter
            return False
        window = tail + chunk if tail else chunk
        try:
            if delimiter in window:
                hit = window.index(delimiter)
                # Rewind from the current read position to just past the hit.
                file.seek(file.tell() - (len(window) - hit) + len(delimiter))
                return True
            elif len(chunk) < blocksize:
                # short read means end-of-file; no delimiter found
                return False
        except (OSError, ValueError):
            # Tolerate file-likes whose contents don't support the check.
            pass
        tail = window[-len(delimiter) :]
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def read_block(
    f: IO[bytes],
    offset: int,
    length: int | None,
    delimiter: bytes | None = None,
    split_before: bool = False,
) -> bytes:
    """Read a block of bytes from a file

    Parameters
    ----------
    f: File
        Open file
    offset: int
        Byte offset to start read
    length: int
        Number of bytes to read, read through end of file if None
    delimiter: bytes (optional)
        Ensure reading starts and stops at delimiter bytestring
    split_before: bool (optional)
        Start/stop read *before* delimiter bytestring.


    If using the ``delimiter=`` keyword argument we ensure that the read
    starts and stops at delimiter boundaries that follow the locations
    ``offset`` and ``offset + length``. If ``offset`` is zero then we
    start at zero, regardless of delimiter. The bytestring returned WILL
    include the terminating delimiter string.

    Examples
    --------

    >>> from io import BytesIO  # doctest: +SKIP
    >>> f = BytesIO(b'Alice, 100\\nBob, 200\\nCharlie, 300')  # doctest: +SKIP
    >>> read_block(f, 0, 13)  # doctest: +SKIP
    b'Alice, 100\\nBo'

    >>> read_block(f, 0, 13, delimiter=b'\\n')  # doctest: +SKIP
    b'Alice, 100\\nBob, 200\\n'

    >>> read_block(f, 10, 10, delimiter=b'\\n')  # doctest: +SKIP
    b'Bob, 200\\nCharlie, 300'
    """
    if delimiter:
        f.seek(offset)
        begin_hit = seek_delimiter(f, delimiter, 2**16)
        if length is None:
            return f.read()
        begin = f.tell()
        # Shrink the requested window by however far the start moved forward.
        length -= begin - offset

        f.seek(begin + length)
        end_hit = seek_delimiter(f, delimiter, 2**16)
        stop = f.tell()

        # When splitting *before* the delimiter, step back over it — but only
        # where a delimiter was actually found (not at file start/end).
        if begin_hit and split_before:
            begin -= len(delimiter)
        if end_hit and split_before:
            stop -= len(delimiter)

        offset = begin
        length = stop - begin

    f.seek(offset)

    # TODO: allow length to be None and read to the end of the file?
    assert length is not None
    return f.read(length)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def tokenize(*args: Any, **kwargs: Any) -> str:
    """Deterministic token

    (modified from dask.base)

    >>> tokenize([1, 2, '3'])
    '9d71491b50023b06fc76928e6eddb952'

    >>> tokenize('Hello') == tokenize('Hello')
    True
    """
    # Fold kwargs into the positional tuple so they contribute to the hash.
    if kwargs:
        args += (kwargs,)
    payload = str(args).encode()
    try:
        hasher = md5(payload)
    except ValueError:
        # FIPS systems: https://github.com/fsspec/filesystem_spec/issues/380
        hasher = md5(payload, usedforsecurity=False)
    return hasher.hexdigest()
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def stringify_path(filepath: str | os.PathLike[str] | pathlib.Path) -> str:
    """Attempt to convert a path-like object to a string.

    Parameters
    ----------
    filepath: object to be converted

    Returns
    -------
    filepath_str: maybe a string version of the object

    Notes
    -----
    Objects supporting the fspath protocol are coerced according to its
    __fspath__ method.

    For backwards compatibility with older Python version, pathlib.Path
    objects are specially coerced.

    Any other object is passed through unchanged, which includes bytes,
    strings, buffers, or anything else that's not even path-like.
    """
    if isinstance(filepath, str):
        return filepath
    # os.PathLike protocol (covers pathlib.Path and friends).
    fspath = getattr(filepath, "__fspath__", None)
    if fspath is not None:
        return fspath()
    # Some objects (e.g. OpenFile-likes) expose a plain .path attribute.
    if hasattr(filepath, "path"):
        return filepath.path
    return filepath  # type: ignore[return-value]
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def make_instance(
    cls: Callable[..., T], args: Sequence[Any], kwargs: dict[str, Any]
) -> T:
    """Construct ``cls(*args, **kwargs)`` and invoke its worker hook.

    NOTE(review): presumably used when reconstructing filesystem instances
    (e.g. after pickling to a worker) — confirm with callers.
    """
    instance = cls(*args, **kwargs)
    instance._determine_worker()  # type: ignore[attr-defined]
    return instance
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def common_prefix(paths: Iterable[str]) -> str:
    """For a list of paths, find the shortest prefix common to all.

    The prefix is computed on whole "/"-separated components, never on
    partial component text.
    """
    parts = [p.split("/") for p in paths]
    # Only compare up to the shortest path's component count.
    depth_limit = min(len(p) for p in parts)
    matched_upto = 0
    for idx in range(depth_limit):
        if all(p[idx] == parts[0][idx] for p in parts):
            matched_upto = idx + 1
        else:
            break
    return "/".join(parts[0][:matched_upto])
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def other_paths(
    paths: list[str],
    path2: str | list[str],
    exists: bool = False,
    flatten: bool = False,
) -> list[str]:
    """In bulk file operations, construct a new file tree from a list of files

    Parameters
    ----------
    paths: list of str
        The input file tree
    path2: str or list of str
        Root to construct the new list in. If this is already a list of str, we just
        assert it has the right number of elements.
    exists: bool (optional)
        For a str destination, it is already exists (and is a dir), files should
        end up inside.
    flatten: bool (optional)
        Whether to flatten the input directory tree structure so that the output files
        are in the same directory.

    Returns
    -------
    list of str
    """
    if not isinstance(path2, str):
        # Explicit destination list: must pair 1:1 with the sources.
        assert len(paths) == len(path2)
        return path2

    root = path2.rstrip("/")
    if flatten:
        # Collapse the tree: every file lands directly under `root`.
        return ["/".join((root, p.split("/")[-1])) for p in paths]

    cp = common_prefix(paths)
    if exists:
        # Destination dir already exists: keep the last shared source
        # component, so files end up inside it.
        cp = cp.rsplit("/", 1)[0]
    if not cp and all(not s.startswith("/") for s in paths):
        # No common prefix and all-relative sources: prepend root verbatim.
        return ["/".join([root, p]) for p in paths]
    return [p.replace(cp, root, 1) for p in paths]
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def is_exception(obj: Any) -> bool:
    """Return True if *obj* is an exception instance (any ``BaseException``)."""
    return isinstance(obj, BaseException)
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def isfilelike(f: Any) -> TypeGuard[IO[bytes]]:
    """Duck-type check: does *f* expose ``read``, ``close`` and ``tell``?"""
    for attr in ("read", "close", "tell"):
        if not hasattr(f, attr):
            return False
    return True
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def get_protocol(url: str) -> str:
    """Return the protocol of a URL-like path, defaulting to ``"file"``.

    Splits on the first ``://`` (plain URL) or ``::`` (chained filesystems).
    """
    url = stringify_path(url)
    pieces = re.split(r"(\:\:|\://)", url, maxsplit=1)
    if len(pieces) > 1:
        return pieces[0]
    return "file"
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def can_be_local(path: str) -> bool:
    """Can the given URL be used with open_local?"""
    from fsspec import get_filesystem_class

    try:
        fs_cls = get_filesystem_class(get_protocol(path))
    except (ValueError, ImportError):
        # not in registry or import failed
        return False
    # Filesystems advertising `local_file` can hand back real local paths.
    return getattr(fs_cls, "local_file", False)
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def get_package_version_without_import(name: str) -> str | None:
    """For given package name, try to find the version without importing it

    Import and package.__version__ is still the backup here, so an import
    *might* happen.

    Returns either the version string, or None if the package
    or the version was not readily found.
    """
    # Fast path: module already imported and carries a version attribute.
    if name in sys.modules:
        mod = sys.modules[name]
        if hasattr(mod, "__version__"):
            return mod.__version__
    try:
        return version(name)
    except Exception:
        # Typically importlib.metadata.PackageNotFoundError; fall through to a
        # real import as last resort. (Was a bare `except:`, which also
        # swallowed KeyboardInterrupt/SystemExit.)
        pass
    try:
        import importlib

        mod = importlib.import_module(name)
        return mod.__version__
    except (ImportError, AttributeError):
        return None
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def setup_logging(
    logger: logging.Logger | None = None,
    logger_name: str | None = None,
    level: str = "DEBUG",
    clear: bool = True,
) -> logging.Logger:
    """Attach a stream handler with a standard format and set the level.

    Exactly one of ``logger`` or ``logger_name`` must be provided; raises
    ``ValueError`` when both are missing. With ``clear=True`` (default), any
    pre-existing handlers on the logger are removed first.
    """
    if logger is None and logger_name is None:
        raise ValueError("Provide either logger object or logger name")
    target = logger if logger is not None else logging.getLogger(logger_name)
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s -- %(message)s"
        )
    )
    if clear:
        target.handlers.clear()
    target.addHandler(handler)
    target.setLevel(level)
    return target
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def _unstrip_protocol(name: str, fs: AbstractFileSystem) -> str:
    """Re-attach *fs*'s protocol prefix to the bare path *name*.

    NOTE(review): exists as a module-level function rather than a lambda or
    bound method — presumably so it can be pickled; confirm with callers.
    """
    return fs.unstrip_protocol(name)
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
def mirror_from(
    origin_name: str, methods: Iterable[str]
) -> Callable[[type[T]], type[T]]:
    """Class decorator: expose each name in *methods* as a read-only property
    that forwards to the same attribute on ``self.<origin_name>``."""

    def _forward(attr: str, self: Any) -> Any:
        # Property getter: look the attribute up on the origin object.
        origin = getattr(self, origin_name)
        return getattr(origin, attr)

    def wrapper(cls: type[T]) -> type[T]:
        for attr in methods:
            setattr(cls, attr, property(partial(_forward, attr)))
        return cls

    return wrapper
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
@contextlib.contextmanager
def nullcontext(obj: T) -> Iterator[T]:
    """Context manager that yields *obj* unchanged and does nothing on exit.

    NOTE(review): stdlib ``contextlib.nullcontext`` (3.7+) offers the same
    behaviour — presumably kept for backward compatibility; confirm.
    """
    yield obj
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def merge_offset_ranges(
    paths: list[str],
    starts: list[int] | int,
    ends: list[int] | int,
    max_gap: int = 0,
    max_block: int | None = None,
    sort: bool = True,
) -> tuple[list[str], list[int], list[int]]:
    """Merge adjacent byte-offset ranges when the inter-range
    gap is <= `max_gap`, and when the merged byte range does not
    exceed `max_block` (if specified). By default, this function
    will re-order the input paths and byte ranges to ensure sorted
    order. If the user can guarantee that the inputs are already
    sorted, passing `sort=False` will skip the re-ordering.

    Raises TypeError if ``paths`` is not a list, and ValueError if the
    starts/ends lists do not match ``paths`` in length.
    """
    # Check input
    if not isinstance(paths, list):
        raise TypeError
    if not isinstance(starts, list):
        # A scalar start applies to every path.
        starts = [starts] * len(paths)
    if not isinstance(ends, list):
        # A scalar end applies to every path.
        ends = [ends] * len(paths)
    if len(starts) != len(paths) or len(ends) != len(paths):
        raise ValueError

    # Early Return
    if len(starts) <= 1:
        return paths, starts, ends

    # Normalize None starts to 0 so the gap arithmetic below is well-defined.
    # NOTE(review): a None *end* is kept as-is and treated as "to end of file"
    # by the merge loop — confirm callers rely on that.
    starts = [s or 0 for s in starts]
    # Sort by paths and then ranges if `sort=True`
    if sort:
        paths, starts, ends = (
            list(v)
            for v in zip(
                *sorted(
                    zip(paths, starts, ends),
                )
            )
        )

    if paths:
        # Loop through the coupled `paths`, `starts`, and
        # `ends`, and merge adjacent blocks when appropriate
        new_paths = paths[:1]
        new_starts = starts[:1]
        new_ends = ends[:1]
        for i in range(1, len(paths)):
            if paths[i] == paths[i - 1] and new_ends[-1] is None:
                # The open-ended (None) range already covers the rest of this
                # file, so this range is subsumed.
                continue
            elif (
                paths[i] != paths[i - 1]
                or ((starts[i] - new_ends[-1]) > max_gap)
                or (max_block is not None and (ends[i] - new_starts[-1]) > max_block)
            ):
                # Cannot merge with previous block.
                # Add new `paths`, `starts`, and `ends` elements
                new_paths.append(paths[i])
                new_starts.append(starts[i])
                new_ends.append(ends[i])
            else:
                # Merge with previous block by updating the
                # last element of `ends`
                new_ends[-1] = ends[i]
        return new_paths, new_starts, new_ends

    # `paths` is empty. Just return input lists
    return paths, starts, ends
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
def file_size(filelike: IO[bytes]) -> int:
    """Find length of any open read-mode file-like.

    The current position is restored before returning.
    """
    saved_pos = filelike.tell()
    try:
        # Seeking to the end returns the absolute offset, i.e. the size.
        return filelike.seek(0, 2)
    finally:
        filelike.seek(saved_pos)
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
@contextlib.contextmanager
def atomic_write(path: str, mode: str = "wb"):
    """
    A context manager that opens a temporary file next to `path` and, on exit,
    replaces `path` with the temporary file, thereby updating `path`
    atomically.
    """
    # The temp file lives in the same directory so os.replace stays atomic
    # (same filesystem).
    fd, tmp_name = tempfile.mkstemp(
        dir=os.path.dirname(path), prefix=os.path.basename(path) + "-"
    )
    try:
        with open(fd, mode) as stream:
            yield stream
    except BaseException:
        # Writing failed: drop the partial temp file, then propagate.
        with contextlib.suppress(FileNotFoundError):
            os.unlink(tmp_name)
        raise
    else:
        # Success: atomically move the finished temp file into place.
        os.replace(tmp_name, path)
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
def _translate(pat, STAR, QUESTION_MARK):
    # Copied from: https://github.com/python/cpython/pull/106703.
    # Translate one glob segment `pat` into a list of regex fragments, using
    # the caller-supplied regex text for `*` (STAR) and `?` (QUESTION_MARK).
    res: list[str] = []
    add = res.append
    i, n = 0, len(pat)
    while i < n:
        c = pat[i]
        i = i + 1
        if c == "*":
            # compress consecutive `*` into one
            if (not res) or res[-1] is not STAR:
                add(STAR)
        elif c == "?":
            add(QUESTION_MARK)
        elif c == "[":
            # Scan ahead for the matching "]" of a character class; a leading
            # "!" (negation) or "]" (literal) is part of the class body.
            j = i
            if j < n and pat[j] == "!":
                j = j + 1
            if j < n and pat[j] == "]":
                j = j + 1
            while j < n and pat[j] != "]":
                j = j + 1
            if j >= n:
                # Unterminated "[": treat it as a literal bracket.
                add("\\[")
            else:
                stuff = pat[i:j]
                if "-" not in stuff:
                    stuff = stuff.replace("\\", r"\\")
                else:
                    # Split the class body on range-forming hyphens so only
                    # literal hyphens get escaped below.
                    chunks = []
                    k = i + 2 if pat[i] == "!" else i + 1
                    while True:
                        k = pat.find("-", k, j)
                        if k < 0:
                            break
                        chunks.append(pat[i:k])
                        i = k + 1
                        k = k + 3
                    chunk = pat[i:j]
                    if chunk:
                        chunks.append(chunk)
                    else:
                        # Trailing hyphen is a literal, attach it to the last chunk.
                        chunks[-1] += "-"
                    # Remove empty ranges -- invalid in RE.
                    for k in range(len(chunks) - 1, 0, -1):
                        if chunks[k - 1][-1] > chunks[k][0]:
                            chunks[k - 1] = chunks[k - 1][:-1] + chunks[k][1:]
                            del chunks[k]
                    # Escape backslashes and hyphens for set difference (--).
                    # Hyphens that create ranges shouldn't be escaped.
                    stuff = "-".join(
                        s.replace("\\", r"\\").replace("-", r"\-") for s in chunks
                    )
                # Escape set operations (&&, ~~ and ||).
                stuff = re.sub(r"([&~|])", r"\\\1", stuff)
                i = j + 1
                if not stuff:
                    # Empty range: never match.
                    add("(?!)")
                elif stuff == "!":
                    # Negated empty range: match any character.
                    add(".")
                else:
                    if stuff[0] == "!":
                        # Glob negation "!" becomes regex negation "^".
                        stuff = "^" + stuff[1:]
                    elif stuff[0] in ("^", "["):
                        stuff = "\\" + stuff
                    add(f"[{stuff}]")
        else:
            # Any other character is matched literally.
            add(re.escape(c))
    assert i == n
    return res
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
def glob_translate(pat):
    # Copied from: https://github.com/python/cpython/pull/106703.
    # The keyword parameters' values are fixed to:
    # recursive=True, include_hidden=True, seps=None
    """Translate a pathname with shell wildcards to a regular expression."""
    if os.path.altsep:
        # e.g. Windows, where both os.sep and os.altsep separate components.
        seps = os.path.sep + os.path.altsep
    else:
        seps = os.path.sep
    escaped_seps = "".join(map(re.escape, seps))
    any_sep = f"[{escaped_seps}]" if len(seps) > 1 else escaped_seps
    not_sep = f"[^{escaped_seps}]"
    # Regex fragments for the possible per-segment shapes.
    one_last_segment = f"{not_sep}+"
    one_segment = f"{one_last_segment}{any_sep}"
    any_segments = f"(?:.+{any_sep})?"
    any_last_segments = ".*"
    results = []
    parts = re.split(any_sep, pat)
    last_part_idx = len(parts) - 1
    for idx, part in enumerate(parts):
        if part == "*":
            # A lone "*" matches exactly one path component.
            results.append(one_segment if idx < last_part_idx else one_last_segment)
            continue
        if part == "**":
            # "**" matches any number of components (recursive).
            results.append(any_segments if idx < last_part_idx else any_last_segments)
            continue
        elif "**" in part:
            raise ValueError(
                "Invalid pattern: '**' can only be an entire path component"
            )
        if part:
            # Translate ordinary glob characters within this component.
            results.extend(_translate(part, f"{not_sep}*", not_sep))
        if idx < last_part_idx:
            results.append(any_sep)
    res = "".join(results)
    # (?s:...) lets "." match newlines; \Z anchors at end of string.
    return rf"(?s:{res})\Z"
|