xiaoanyu123 commited on
Commit
092c4a6
·
verified ·
1 Parent(s): bf8cf37

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  2. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +712 -0
  3. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_edm_euler.py +448 -0
  4. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +482 -0
  5. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_euler_discrete.py +757 -0
  6. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_euler_discrete_flax.py +265 -0
  7. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +561 -0
  8. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  9. pythonProject/.venv/Lib/site-packages/fsspec/__init__.py +71 -0
  10. pythonProject/.venv/Lib/site-packages/fsspec/_version.py +34 -0
  11. pythonProject/.venv/Lib/site-packages/fsspec/implementations/arrow.py +304 -0
  12. pythonProject/.venv/Lib/site-packages/fsspec/implementations/asyn_wrapper.py +122 -0
  13. pythonProject/.venv/Lib/site-packages/fsspec/implementations/cache_mapper.py +75 -0
  14. pythonProject/.venv/Lib/site-packages/fsspec/implementations/cache_metadata.py +233 -0
  15. pythonProject/.venv/Lib/site-packages/fsspec/implementations/cached.py +998 -0
  16. pythonProject/.venv/Lib/site-packages/fsspec/implementations/dask.py +152 -0
  17. pythonProject/.venv/Lib/site-packages/fsspec/implementations/data.py +58 -0
  18. pythonProject/.venv/Lib/site-packages/fsspec/implementations/dbfs.py +496 -0
  19. pythonProject/.venv/Lib/site-packages/fsspec/implementations/dirfs.py +388 -0
  20. pythonProject/.venv/Lib/site-packages/fsspec/utils.py +737 -0
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_dpm_cogvideox.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
17
+ # and https://github.com/hojonathanho/diffusion
18
+
19
+ import math
20
+ from dataclasses import dataclass
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+
26
+ from ..configuration_utils import ConfigMixin, register_to_config
27
+ from ..utils import BaseOutput
28
+ from ..utils.torch_utils import randn_tensor
29
+ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
30
+
31
+
32
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class DDIMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # Denoised sample x_{t-1}; fed back to the model on the next denoising step.
    prev_sample: torch.Tensor
    # Model's current estimate of the fully denoised sample x_0 (optional; useful for previews/guidance).
    pred_original_sample: Optional[torch.Tensor] = None
49
+
50
+
51
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
52
+ def betas_for_alpha_bar(
53
+ num_diffusion_timesteps,
54
+ max_beta=0.999,
55
+ alpha_transform_type="cosine",
56
+ ):
57
+ """
58
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
59
+ (1-beta) over time from t = [0,1].
60
+
61
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
62
+ to that part of the diffusion process.
63
+
64
+
65
+ Args:
66
+ num_diffusion_timesteps (`int`): the number of betas to produce.
67
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
68
+ prevent singularities.
69
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
70
+ Choose from `cosine` or `exp`
71
+
72
+ Returns:
73
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
74
+ """
75
+ if alpha_transform_type == "cosine":
76
+
77
+ def alpha_bar_fn(t):
78
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
79
+
80
+ elif alpha_transform_type == "exp":
81
+
82
+ def alpha_bar_fn(t):
83
+ return math.exp(t * -12.0)
84
+
85
+ else:
86
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
87
+
88
+ betas = []
89
+ for i in range(num_diffusion_timesteps):
90
+ t1 = i / num_diffusion_timesteps
91
+ t2 = (i + 1) / num_diffusion_timesteps
92
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
93
+ return torch.tensor(betas, dtype=torch.float32)
94
+
95
+
96
def rescale_zero_terminal_snr(alphas_cumprod):
    """
    Rescales the cumulative alpha products to have zero terminal SNR. Based on
    https://huggingface.co/papers/2305.08891 (Algorithm 1), adapted to operate directly on
    `alphas_cumprod` instead of betas.

    Args:
        alphas_cumprod (`torch.Tensor`):
            the cumulative products of (1 - beta) that the scheduler is being initialized with.
            Note: this tensor is modified in place.

    Returns:
        `torch.Tensor`: rescaled `alphas_cumprod` whose last entry is exactly 0 (zero terminal SNR)
        while the first entry keeps its original value.
    """

    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt

    return alphas_bar
125
+
126
+
127
class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
    """
    `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
    non-Markovian guidance.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        clip_sample (`bool`, defaults to `True`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, defaults to `True`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alpha value at step 0.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
        snr_shift_scale (`float`, defaults to 3.0):
            Scale applied when shifting the SNR of the schedule ("SNR shift following SD3" below); 1.0 leaves
            `alphas_cumprod` unchanged.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.0120,
        beta_schedule: str = "scaled_linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        rescale_betas_zero_snr: bool = False,
        snr_shift_scale: float = 3.0,
    ):
        # Build the beta schedule; `trained_betas` bypasses the named schedules entirely.
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Modify: SNR shift following SD3
        # (with snr_shift_scale == 1.0 this is the identity transform)
        self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod)

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        # `timesteps` defaults to the full training schedule in reverse; `set_timesteps` replaces it.
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

    def _get_variance(self, timestep, prev_timestep):
        # Variance of the reverse transition between `timestep` and `prev_timestep`,
        # expressed via the cumulative alpha products; a negative `prev_timestep` falls
        # back to `final_alpha_cumprod` (see `set_alpha_to_one`).
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        # This scheduler performs no input scaling; the sample is returned unchanged.
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved.

        Raises:
            ValueError: if `num_inference_steps` exceeds `num_train_timesteps`, or if
                `timestep_spacing` is not one of `"linspace"`, `"leading"`, `"trailing"`.
        """

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                .round()[::-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            timesteps -= 1
        else:
            # NOTE(review): the message omits 'linspace', which is in fact supported above.
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
            )

        self.timesteps = torch.from_numpy(timesteps).to(device)

    def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None):
        # Half log-SNR lambda = log(sqrt(alpha_bar / (1 - alpha_bar))) at the current and
        # next noise levels; `h` is the step size in lambda-space used by the solver update.
        lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log()
        lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log()
        h = lamb_next - lamb

        # When the previous step's noise level is available, also return `r`, the ratio of
        # the last step size to the current one (used for the multistep correction).
        if alpha_prod_t_back is not None:
            lamb_previous = ((alpha_prod_t_back / (1 - alpha_prod_t_back)) ** 0.5).log()
            h_last = lamb - lamb_previous
            r = h_last / h
            return h, r, lamb, lamb_next
        else:
            return h, None, lamb, lamb_next

    def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back):
        # Update coefficients: mult1 scales the current sample, mult2 scales the denoised
        # (x0) prediction. mult3/mult4 blend the current and previous x0 predictions for the
        # second-order multistep correction. NOTE(review): these appear to follow the SDE
        # variant of DPM-Solver++ — confirm against the reference implementation.
        mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp()
        mult2 = (-2 * h).expm1() * alpha_prod_t_prev**0.5

        if alpha_prod_t_back is not None:
            mult3 = 1 + 1 / (2 * r)
            mult4 = 1 / (2 * r)
            return mult1, mult2, mult3, mult4
        else:
            return mult1, mult2

    def step(
        self,
        model_output: torch.Tensor,
        old_pred_original_sample: torch.Tensor,
        timestep: int,
        timestep_back: int,
        sample: torch.Tensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.Tensor] = None,
        return_dict: bool = False,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            old_pred_original_sample (`torch.Tensor`):
                The `pred_original_sample` returned by the previous call to `step`, or `None` on the first
                step; when provided it enables the second-order multistep update.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            timestep_back (`int`):
                The timestep of the previous solver step (used to compute the step-size ratio), or `None`.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            eta (`float`):
                The weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`, defaults to `False`):
                If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
                because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
                clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
                `use_clipped_model_output` has no effect.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.Tensor`):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`CycleDiffusion`].
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://huggingface.co/papers/2010.02502
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        alpha_prod_t_back = self.alphas_cumprod[timestep_back] if timestep_back is not None else None

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502
        # To make style tests pass, commented out `pred_epsilon` as it is an unused variable
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            # pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # Solver step sizes and update coefficients in log-SNR space.
        h, r, lamb, lamb_next = self.get_variables(alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back)
        mult = list(self.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back))
        mult_noise = (1 - alpha_prod_t_prev) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5

        # First-order (single-step) stochastic update.
        noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
        prev_sample = mult[0] * sample - mult[1] * pred_original_sample + mult_noise * noise

        if old_pred_original_sample is None or prev_timestep < 0:
            # Save a network evaluation if all noise levels are 0 or on the first step
            # NOTE(review): this early exit always returns a tuple, ignoring `return_dict`.
            return prev_sample, pred_original_sample
        else:
            # Second-order multistep correction: extrapolate x0 from the current and
            # previous predictions before redoing the update. A fresh noise sample is
            # drawn, so `prev_sample` computed above is discarded.
            denoised_d = mult[2] * pred_original_sample - mult[3] * old_pred_original_sample
            noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
            x_advanced = mult[0] * sample - mult[1] * denoised_d + mult_noise * noise

            prev_sample = x_advanced

        if not return_dict:
            return (prev_sample, pred_original_sample)

        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """Forward diffusion: mix `original_samples` with `noise` at the given `timesteps`."""
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
        # for the subsequent add_noise calls
        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        # Broadcast the per-timestep scalars across the sample's trailing dimensions.
        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
        """Compute the v-prediction target for `sample`/`noise` at the given `timesteps`."""
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        # Broadcast the per-timestep scalars across the sample's trailing dimensions.
        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self):
        # The scheduler's "length" is the size of its training schedule.
        return self.config.num_train_timesteps
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py ADDED
@@ -0,0 +1,712 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
16
+
17
+ import math
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from ..configuration_utils import ConfigMixin, register_to_config
24
+ from ..utils.torch_utils import randn_tensor
25
+ from .scheduling_utils import SchedulerMixin, SchedulerOutput
26
+
27
+
28
+ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
29
+ """
30
+ Implements DPMSolverMultistepScheduler in EDM formulation as presented in Karras et al. 2022 [1].
31
+ `EDMDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
32
+
33
+ [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
34
+ https://huggingface.co/papers/2206.00364
35
+
36
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
37
+ methods the library implements for all schedulers such as loading and saving.
38
+
39
+ Args:
40
+ sigma_min (`float`, *optional*, defaults to 0.002):
41
+ Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the EDM paper [1]; a reasonable
42
+ range is [0, 10].
43
+ sigma_max (`float`, *optional*, defaults to 80.0):
44
+ Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable
45
+ range is [0.2, 80.0].
46
+ sigma_data (`float`, *optional*, defaults to 0.5):
47
+ The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
48
+ sigma_schedule (`str`, *optional*, defaults to `karras`):
49
+ Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
50
+ (https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential
51
+ schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
52
+ num_train_timesteps (`int`, defaults to 1000):
53
+ The number of diffusion steps to train the model.
54
+ solver_order (`int`, defaults to 2):
55
+ The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
56
+ sampling, and `solver_order=3` for unconditional sampling.
57
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
58
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
59
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
60
+ Video](https://imagen.research.google/video/paper.pdf) paper).
61
+ thresholding (`bool`, defaults to `False`):
62
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
63
+ as Stable Diffusion.
64
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
65
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
66
+ sample_max_value (`float`, defaults to 1.0):
67
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
68
+ `algorithm_type="dpmsolver++"`.
69
+ algorithm_type (`str`, defaults to `dpmsolver++`):
70
+ Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver++` type implements
71
+ the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to
72
+ use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
73
+ solver_type (`str`, defaults to `midpoint`):
74
+ Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
75
+ sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
76
+ lower_order_final (`bool`, defaults to `True`):
77
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
78
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
79
+ euler_at_final (`bool`, defaults to `False`):
80
+ Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
81
+ richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
82
+ steps, but sometimes may result in blurring.
83
+ final_sigmas_type (`str`, defaults to `"zero"`):
84
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
85
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
86
+ """
87
+
88
+ _compatibles = []
89
+ order = 1
90
+
91
@register_to_config
def __init__(
    self,
    sigma_min: float = 0.002,
    sigma_max: float = 80.0,
    sigma_data: float = 0.5,
    sigma_schedule: str = "karras",
    num_train_timesteps: int = 1000,
    prediction_type: str = "epsilon",
    rho: float = 7.0,
    solver_order: int = 2,
    thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    algorithm_type: str = "dpmsolver++",
    solver_type: str = "midpoint",
    lower_order_final: bool = True,
    euler_at_final: bool = False,
    final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
):
    """Build the training-resolution sigma schedule and initialize solver state."""
    # settings for DPM-Solver
    if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"]:
        if algorithm_type == "deis":
            # Legacy alias: silently upgrade "deis" to the equivalent dpmsolver++ config.
            self.register_to_config(algorithm_type="dpmsolver++")
        else:
            raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")

    if solver_type not in ["midpoint", "heun"]:
        if solver_type in ["logrho", "bh1", "bh2"]:
            # Legacy solver-type names map onto the midpoint rule here.
            self.register_to_config(solver_type="midpoint")
        else:
            raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

    if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
        raise ValueError(
            f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
        )

    # Training-resolution sigma schedule (descending from sigma_max towards sigma_min).
    ramp = torch.linspace(0, 1, num_train_timesteps)
    if sigma_schedule == "karras":
        sigmas = self._compute_karras_sigmas(ramp)
    elif sigma_schedule == "exponential":
        sigmas = self._compute_exponential_sigmas(ramp)
    # NOTE(review): any other `sigma_schedule` value leaves `sigmas` undefined and would
    # raise NameError below — presumably validated by the caller; confirm.

    # Timesteps are the EDM noise-conditioning values c_noise(sigma), not integer indices.
    self.timesteps = self.precondition_noise(sigmas)

    # Append a terminal sigma of 0 so `sigmas[step_index + 1]` is valid at the last step.
    self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

    # setable values
    self.num_inference_steps = None
    self.model_outputs = [None] * solver_order  # history buffer of converted model outputs
    self.lower_order_nums = 0  # number of warm-up (lower-order) steps taken so far
    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
146
+
147
@property
def init_noise_sigma(self):
    """Standard deviation of the initial noise distribution."""
    sigma_max = self.config.sigma_max
    return (sigma_max * sigma_max + 1) ** 0.5
151
+
152
@property
def step_index(self):
    """Index of the current timestep; advances by one after every scheduler step."""
    return self._step_index
158
+
159
@property
def begin_index(self):
    """Index of the first timestep, as recorded by `set_begin_index` from the pipeline."""
    return self._begin_index
165
+
166
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
    """
    Record the index of the first timestep. Pipelines call this before inference starts.

    Args:
        begin_index (`int`):
            The begin index for the scheduler.
    """
    self._begin_index = begin_index
176
+
177
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
def precondition_inputs(self, sample, sigma):
    """Scale `sample` by the EDM input-conditioning factor c_in(sigma)."""
    return sample * self._get_conditioning_c_in(sigma)
182
+
183
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_noise
def precondition_noise(self, sigma):
    """Map sigma to the EDM noise-conditioning value c_noise = log(sigma) / 4."""
    sigma_t = sigma if isinstance(sigma, torch.Tensor) else torch.tensor([sigma])
    return torch.log(sigma_t) * 0.25
191
+
192
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
def precondition_outputs(self, sample, model_output, sigma):
    """Combine the raw network output with the noisy sample via the EDM skip/output scalings."""
    sigma_data = self.config.sigma_data
    variance = sigma**2 + sigma_data**2
    c_skip = sigma_data**2 / variance

    c_out = sigma * sigma_data / variance**0.5
    if self.config.prediction_type == "v_prediction":
        c_out = -c_out
    elif self.config.prediction_type != "epsilon":
        raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")

    return c_skip * sample + c_out * model_output
207
+
208
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
    """
    Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
    current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

    Args:
        sample (`torch.Tensor`):
            The input sample.
        timestep (`int`, *optional*):
            The current timestep in the diffusion chain.

    Returns:
        `torch.Tensor`:
            A scaled input sample.
    """
    # NOTE(review): the docstring above is inherited via Copied-from; the code actually
    # applies the EDM factor c_in(sigma) through `precondition_inputs` — confirm wording.
    # Lazily resolve the step index on first use (pipelines may not set it explicitly).
    if self.step_index is None:
        self._init_step_index(timestep)

    sigma = self.sigmas[self.step_index]
    sample = self.precondition_inputs(sample, sigma)

    # Flag used by pipelines to warn when input scaling was skipped.
    self.is_scale_input_called = True
    return sample
232
+
233
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
    """

    self.num_inference_steps = num_inference_steps

    # Rebuild the sigma schedule at inference resolution.
    ramp = torch.linspace(0, 1, self.num_inference_steps)
    if self.config.sigma_schedule == "karras":
        sigmas = self._compute_karras_sigmas(ramp)
    elif self.config.sigma_schedule == "exponential":
        sigmas = self._compute_exponential_sigmas(ramp)
    # NOTE(review): other `sigma_schedule` values would leave `sigmas` unbound here;
    # assumed to be validated earlier — confirm.

    sigmas = sigmas.to(dtype=torch.float32, device=device)
    # Timesteps are the EDM noise-conditioning values c_noise(sigma).
    self.timesteps = self.precondition_noise(sigmas)

    if self.config.final_sigmas_type == "sigma_min":
        sigma_last = self.config.sigma_min
    elif self.config.final_sigmas_type == "zero":
        sigma_last = 0
    else:
        raise ValueError(
            f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
        )

    # Terminal sigma so `sigmas[step_index + 1]` is valid at the final step.
    self.sigmas = torch.cat([sigmas, torch.tensor([sigma_last], dtype=torch.float32, device=device)])

    # Reset the multistep history and warm-up counter.
    self.model_outputs = [
        None,
    ] * self.config.solver_order
    self.lower_order_nums = 0

    # add an index counter for schedulers that allow duplicated timesteps
    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
275
+
276
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
    """Karras et al. (2022) schedule: interpolate in rho-space from sigma_max down to sigma_min."""
    lo = sigma_min if sigma_min else self.config.sigma_min
    hi = sigma_max if sigma_max else self.config.sigma_max
    rho = self.config.rho

    hi_inv_rho = hi ** (1 / rho)
    lo_inv_rho = lo ** (1 / rho)
    return (hi_inv_rho + ramp * (lo_inv_rho - hi_inv_rho)) ** rho
287
+
288
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
    """Implementation closely follows k-diffusion.

    https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
    """
    lo = sigma_min or self.config.sigma_min
    hi = sigma_max or self.config.sigma_max
    # Evenly spaced in log-sigma, then flipped so sigmas run from high to low.
    log_sigmas = torch.linspace(math.log(lo), math.log(hi), len(ramp))
    return log_sigmas.exp().flip(0)
298
+
299
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
    """
    "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
    prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
    s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
    pixels from saturation at each step. We find that dynamic thresholding results in significantly better
    photorealism as well as better image-text alignment, especially when using very large guidance weights."

    https://huggingface.co/papers/2205.11487
    """
    dtype = sample.dtype
    batch_size, channels, *remaining_dims = sample.shape

    if dtype not in (torch.float32, torch.float64):
        sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

    # Flatten sample for doing quantile calculation along each image
    sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

    abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

    # One threshold per batch element: the configured percentile of |x0|.
    s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
    s = torch.clamp(
        s, min=1, max=self.config.sample_max_value
    )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
    s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
    sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

    # Restore the original layout and dtype.
    sample = sample.reshape(batch_size, channels, *remaining_dims)
    sample = sample.to(dtype)

    return sample
332
+
333
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
    """Invert the sigma schedule: map sigma values to fractional timestep indices by log-linear interpolation."""
    # Work in log space; floor at 1e-10 to avoid log(0).
    log_sigma = np.log(np.maximum(sigma, 1e-10))

    # Signed distance from each query to every schedule entry.
    dists = log_sigma - log_sigmas[:, np.newaxis]

    # Lower edge of the bracketing pair of schedule entries for each query.
    low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1

    low = log_sigmas[low_idx]
    high = log_sigmas[high_idx]

    # Interpolation weight within the bracket, clamped to [0, 1].
    w = np.clip((low - log_sigma) / (low - high), 0, 1)

    # Fractional timestep is the weighted blend of the two bracket indices.
    return ((1 - w) * low_idx + w * high_idx).reshape(sigma.shape)
356
+
357
def _sigma_to_alpha_sigma_t(self, sigma):
    """Return (alpha_t, sigma_t) for a sigma; alpha_t is identically 1 in this formulation."""
    # Inputs are pre-scaled before going into unet, so alpha_t = 1
    return torch.tensor(1), sigma
362
+
363
def convert_model_output(
    self,
    model_output: torch.Tensor,
    sample: torch.Tensor = None,
) -> torch.Tensor:
    """
    Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
    designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
    integral of the data prediction model.

    <Tip>

    The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
    prediction and data prediction models.

    </Tip>

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.

    Returns:
        `torch.Tensor`:
            The converted model output.
    """
    # Both supported algorithm types here work on the data prediction, so the EDM output
    # preconditioning already yields the x0 estimate directly.
    sigma = self.sigmas[self.step_index]
    x0_pred = self.precondition_outputs(sample, model_output, sigma)

    if self.config.thresholding:
        # Optional dynamic thresholding of the x0 prediction (pixel-space models).
        x0_pred = self._threshold_sample(x0_pred)

    return x0_pred
397
+
398
def dpm_solver_first_order_update(
    self,
    model_output: torch.Tensor,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    One step for the first-order DPMSolver (equivalent to DDIM).

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.
    """
    sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
    # alpha is identically 1 here (see `_sigma_to_alpha_sigma_t`), so lambda = -log(sigma).
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s = torch.log(alpha_s) - torch.log(sigma_s)

    # h is the log-SNR increment between the current and the next sigma.
    h = lambda_t - lambda_s
    if self.config.algorithm_type == "dpmsolver++":
        # Deterministic first-order DPM-Solver++ (data-prediction) update.
        x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
    elif self.config.algorithm_type == "sde-dpmsolver++":
        assert noise is not None
        # SDE variant: injects fresh noise scaled so the marginal variance stays correct.
        x_t = (
            (sigma_t / sigma_s * torch.exp(-h)) * sample
            + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
            + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
        )

    return x_t
435
+
436
def multistep_dpm_solver_second_order_update(
    self,
    model_output_list: List[torch.Tensor],
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    One step for the second-order multistep DPMSolver.

    Args:
        model_output_list (`List[torch.Tensor]`):
            The direct outputs from learned diffusion model at current and latter timesteps.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.
    """
    # sigma_t: target (next) sigma; sigma_s0: current; sigma_s1: previous step's sigma.
    sigma_t, sigma_s0, sigma_s1 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
        self.sigmas[self.step_index - 1],
    )

    # alpha is identically 1 here, so each lambda reduces to -log(sigma).
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)

    # m0: newest x0 prediction; m1: the one from the previous step.
    m0, m1 = model_output_list[-1], model_output_list[-2]

    # h: current log-SNR step; h_0: previous step; r0: step-size ratio.
    h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
    r0 = h_0 / h
    # D0/D1: zeroth- and first-order finite differences of the x0 predictions.
    D0, D1 = m0, (1.0 / r0) * (m0 - m1)
    if self.config.algorithm_type == "dpmsolver++":
        # See https://huggingface.co/papers/2211.01095 for detailed derivations
        if self.config.solver_type == "midpoint":
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
            )
        elif self.config.solver_type == "heun":
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
            )
    elif self.config.algorithm_type == "sde-dpmsolver++":
        assert noise is not None
        # SDE variant: same structure plus a variance-preserving noise term.
        if self.config.solver_type == "midpoint":
            x_t = (
                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
        elif self.config.solver_type == "heun":
            x_t = (
                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )

    return x_t
506
+
507
def multistep_dpm_solver_third_order_update(
    self,
    model_output_list: List[torch.Tensor],
    sample: torch.Tensor = None,
) -> torch.Tensor:
    """
    One step for the third-order multistep DPMSolver.

    Args:
        model_output_list (`List[torch.Tensor]`):
            The direct outputs from learned diffusion model at current and latter timesteps.
        sample (`torch.Tensor`):
            A current instance of a sample created by diffusion process.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.
    """
    # Four consecutive sigmas: target, current, and the two previous steps.
    sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
        self.sigmas[self.step_index - 1],
        self.sigmas[self.step_index - 2],
    )

    # alpha is identically 1 here, so each lambda reduces to -log(sigma).
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
    alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
    lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)

    # m0: newest x0 prediction; m1/m2: the two previous ones.
    m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

    # h: current log-SNR step; h_0/h_1: the two previous steps; r0/r1: step-size ratios.
    h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
    r0, r1 = h_0 / h, h_1 / h
    # Finite differences of the x0 predictions up to second order.
    D0 = m0
    D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
    D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
    D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
    # NOTE(review): only the "dpmsolver++" branch is implemented; reaching this method with
    # algorithm_type "sde-dpmsolver++" would leave `x_t` unbound and raise UnboundLocalError
    # below — presumably unreachable in practice, confirm against `step`'s order selection.
    if self.config.algorithm_type == "dpmsolver++":
        # See https://huggingface.co/papers/2206.00927 for detailed derivations
        x_t = (
            (sigma_t / sigma_s0) * sample
            - (alpha_t * (torch.exp(-h) - 1.0)) * D0
            + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
            - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
        )

    return x_t
560
+
561
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """Locate the sigma/step index matching `timestep` in the (possibly duplicated) schedule."""
    if schedule_timesteps is None:
        schedule_timesteps = self.timesteps

    matches = (schedule_timesteps == timestep).nonzero()

    if len(matches) == 0:
        # Timestep not on the schedule: fall back to the final index.
        return len(self.timesteps) - 1

    # The sigma index that is taken for the **very** first `step` is always the second
    # occurrence (or the only one) so we don't accidentally skip a sigma when denoising
    # starts in the middle of the schedule (e.g. for image-to-image).
    pos = 1 if len(matches) > 1 else 0
    return matches[pos].item()
580
+
581
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
    """
    Initialize the step_index counter for the scheduler.
    """
    if self.begin_index is not None:
        # A pipeline-provided begin index takes precedence.
        self._step_index = self._begin_index
        return

    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)
    self._step_index = self.index_for_timestep(timestep)
593
+
594
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[int, torch.Tensor],
    sample: torch.Tensor,
    generator=None,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
    the multistep DPMSolver.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        generator (`torch.Generator`, *optional*):
            A random number generator.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if self.step_index is None:
        self._init_step_index(timestep)

    # Improve numerical stability for small number of steps
    # Force a first-order final step when configured, when the schedule is short, or when
    # the final sigma is exactly zero (higher orders would need log of that sigma).
    lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
        self.config.euler_at_final
        or (self.config.lower_order_final and len(self.timesteps) < 15)
        or self.config.final_sigmas_type == "zero"
    )
    # Cap the second-to-last step at second order for short schedules.
    lower_order_second = (
        (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
    )

    # Convert to an x0 prediction and shift it into the multistep history buffer.
    model_output = self.convert_model_output(model_output, sample=sample)
    for i in range(self.config.solver_order - 1):
        self.model_outputs[i] = self.model_outputs[i + 1]
    self.model_outputs[-1] = model_output

    if self.config.algorithm_type == "sde-dpmsolver++":
        # Fresh noise for the stochastic (SDE) update; deterministic solvers need none.
        noise = randn_tensor(
            model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
        )
    else:
        noise = None

    # Warm-up: run lower orders until enough model-output history has accumulated.
    if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
        prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
    elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
        prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
    else:
        prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)

    if self.lower_order_nums < self.config.solver_order:
        self.lower_order_nums += 1

    # upon completion increase step index by one
    self._step_index += 1

    if not return_dict:
        return (prev_sample,)

    return SchedulerOutput(prev_sample=prev_sample)
671
+
672
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    """Diffuse `original_samples` to the noise levels of `timesteps`: x = x0 + sigma * noise."""
    # Make sure sigmas and timesteps have the same device and dtype as original_samples
    sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
    if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
        # mps does not support float64
        schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
        timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
    else:
        schedule_timesteps = self.timesteps.to(original_samples.device)
        timesteps = timesteps.to(original_samples.device)

    # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
    if self.begin_index is None:
        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
    elif self.step_index is not None:
        # add_noise is called after first denoising step (for inpainting)
        step_indices = [self.step_index] * timesteps.shape[0]
    else:
        # add noise is called before first denoising step to create initial latent(img2img)
        step_indices = [self.begin_index] * timesteps.shape[0]

    sigma = sigmas[step_indices].flatten()
    # Broadcast sigma over the sample's trailing dimensions.
    while len(sigma.shape) < len(original_samples.shape):
        sigma = sigma.unsqueeze(-1)

    noisy_samples = original_samples + noise * sigma
    return noisy_samples
705
+
706
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
def _get_conditioning_c_in(self, sigma):
    """EDM input scaling: c_in = 1 / sqrt(sigma^2 + sigma_data^2)."""
    sigma_data = self.config.sigma_data
    return 1 / ((sigma**2 + sigma_data**2) ** 0.5)
710
+
711
def __len__(self):
    """Length of the training noise schedule."""
    return self.config.num_train_timesteps
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_edm_euler.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import torch
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..utils import BaseOutput, logging
23
+ from ..utils.torch_utils import randn_tensor
24
+ from .scheduling_utils import SchedulerMixin
25
+
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
class EDMEulerSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # NOTE: field layout must stay in sync with DDPMSchedulerOutput (enforced by `make fix-copies`).
    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None
47
+
48
+
49
+ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
50
+ """
51
+ Implements the Euler scheduler in EDM formulation as presented in Karras et al. 2022 [1].
52
+
53
+ [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
54
+ https://huggingface.co/papers/2206.00364
55
+
56
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
57
+ methods the library implements for all schedulers such as loading and saving.
58
+
59
+ Args:
60
+ sigma_min (`float`, *optional*, defaults to 0.002):
61
+ Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the EDM paper [1]; a reasonable
62
+ range is [0, 10].
63
+ sigma_max (`float`, *optional*, defaults to 80.0):
64
+ Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable
65
+ range is [0.2, 80.0].
66
+ sigma_data (`float`, *optional*, defaults to 0.5):
67
+ The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
68
+ sigma_schedule (`str`, *optional*, defaults to `karras`):
69
+ Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
70
+ (https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential
71
+ schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
72
+ num_train_timesteps (`int`, defaults to 1000):
73
+ The number of diffusion steps to train the model.
74
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
75
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
76
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
77
+ Video](https://imagen.research.google/video/paper.pdf) paper).
78
+ rho (`float`, *optional*, defaults to 7.0):
79
+ The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
80
+ final_sigmas_type (`str`, defaults to `"zero"`):
81
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
82
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
83
+ """
84
+
85
+ _compatibles = []
86
+ order = 1
87
+
88
@register_to_config
def __init__(
    self,
    sigma_min: float = 0.002,
    sigma_max: float = 80.0,
    sigma_data: float = 0.5,
    sigma_schedule: str = "karras",
    num_train_timesteps: int = 1000,
    prediction_type: str = "epsilon",
    rho: float = 7.0,
    final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
):
    """Build the training-resolution sigma schedule and initialize scheduler state."""
    if sigma_schedule not in ["karras", "exponential"]:
        # Fixed: the previous message was garbled ("Wrong value for provided for ...`").
        raise ValueError(
            f"Invalid `sigma_schedule` {sigma_schedule!r}; expected 'karras' or 'exponential'."
        )

    # setable values
    self.num_inference_steps = None

    # mps has no float64 support, so compute the schedule in float32 there.
    sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
    sigmas = torch.arange(num_train_timesteps + 1, dtype=sigmas_dtype) / num_train_timesteps
    if sigma_schedule == "karras":
        sigmas = self._compute_karras_sigmas(sigmas)
    elif sigma_schedule == "exponential":
        sigmas = self._compute_exponential_sigmas(sigmas)
    sigmas = sigmas.to(torch.float32)

    # Timesteps are the EDM noise-conditioning values c_noise(sigma), not integer indices.
    self.timesteps = self.precondition_noise(sigmas)

    if self.config.final_sigmas_type == "sigma_min":
        sigma_last = sigmas[-1]
    elif self.config.final_sigmas_type == "zero":
        sigma_last = 0
    else:
        raise ValueError(
            f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
        )

    # Append the terminal sigma so `sigmas[step_index + 1]` is valid on the last step.
    self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])

    self.is_scale_input_called = False

    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
132
+
133
+ @property
134
+ def init_noise_sigma(self):
135
+ # standard deviation of the initial noise distribution
136
+ return (self.config.sigma_max**2 + 1) ** 0.5
137
+
138
+ @property
139
+ def step_index(self):
140
+ """
141
+ The index counter for current timestep. It will increase 1 after each scheduler step.
142
+ """
143
+ return self._step_index
144
+
145
+ @property
146
+ def begin_index(self):
147
+ """
148
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
149
+ """
150
+ return self._begin_index
151
+
152
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
153
+ def set_begin_index(self, begin_index: int = 0):
154
+ """
155
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
156
+
157
+ Args:
158
+ begin_index (`int`):
159
+ The begin index for the scheduler.
160
+ """
161
+ self._begin_index = begin_index
162
+
163
+ def precondition_inputs(self, sample, sigma):
164
+ c_in = self._get_conditioning_c_in(sigma)
165
+ scaled_sample = sample * c_in
166
+ return scaled_sample
167
+
168
+ def precondition_noise(self, sigma):
169
+ if not isinstance(sigma, torch.Tensor):
170
+ sigma = torch.tensor([sigma])
171
+
172
+ c_noise = 0.25 * torch.log(sigma)
173
+
174
+ return c_noise
175
+
176
+ def precondition_outputs(self, sample, model_output, sigma):
177
+ sigma_data = self.config.sigma_data
178
+ c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
179
+
180
+ if self.config.prediction_type == "epsilon":
181
+ c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
182
+ elif self.config.prediction_type == "v_prediction":
183
+ c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
184
+ else:
185
+ raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")
186
+
187
+ denoised = c_skip * sample + c_out * model_output
188
+
189
+ return denoised
190
+
191
+ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
192
+ """
193
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
194
+ current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
195
+
196
+ Args:
197
+ sample (`torch.Tensor`):
198
+ The input sample.
199
+ timestep (`int`, *optional*):
200
+ The current timestep in the diffusion chain.
201
+
202
+ Returns:
203
+ `torch.Tensor`:
204
+ A scaled input sample.
205
+ """
206
+ if self.step_index is None:
207
+ self._init_step_index(timestep)
208
+
209
+ sigma = self.sigmas[self.step_index]
210
+ sample = self.precondition_inputs(sample, sigma)
211
+
212
+ self.is_scale_input_called = True
213
+ return sample
214
+
215
+ def set_timesteps(
216
+ self,
217
+ num_inference_steps: int = None,
218
+ device: Union[str, torch.device] = None,
219
+ sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
220
+ ):
221
+ """
222
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
223
+
224
+ Args:
225
+ num_inference_steps (`int`):
226
+ The number of diffusion steps used when generating samples with a pre-trained model.
227
+ device (`str` or `torch.device`, *optional*):
228
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
229
+ sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
230
+ Custom sigmas to use for the denoising process. If not defined, the default behavior when
231
+ `num_inference_steps` is passed will be used.
232
+ """
233
+ self.num_inference_steps = num_inference_steps
234
+
235
+ sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
236
+ if sigmas is None:
237
+ sigmas = torch.linspace(0, 1, self.num_inference_steps, dtype=sigmas_dtype)
238
+ elif isinstance(sigmas, float):
239
+ sigmas = torch.tensor(sigmas, dtype=sigmas_dtype)
240
+ else:
241
+ sigmas = sigmas.to(sigmas_dtype)
242
+ if self.config.sigma_schedule == "karras":
243
+ sigmas = self._compute_karras_sigmas(sigmas)
244
+ elif self.config.sigma_schedule == "exponential":
245
+ sigmas = self._compute_exponential_sigmas(sigmas)
246
+ sigmas = sigmas.to(dtype=torch.float32, device=device)
247
+
248
+ self.timesteps = self.precondition_noise(sigmas)
249
+
250
+ if self.config.final_sigmas_type == "sigma_min":
251
+ sigma_last = sigmas[-1]
252
+ elif self.config.final_sigmas_type == "zero":
253
+ sigma_last = 0
254
+ else:
255
+ raise ValueError(
256
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
257
+ )
258
+
259
+ self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
260
+ self._step_index = None
261
+ self._begin_index = None
262
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
263
+
264
+ # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
265
+ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
266
+ """Constructs the noise schedule of Karras et al. (2022)."""
267
+ sigma_min = sigma_min or self.config.sigma_min
268
+ sigma_max = sigma_max or self.config.sigma_max
269
+
270
+ rho = self.config.rho
271
+ min_inv_rho = sigma_min ** (1 / rho)
272
+ max_inv_rho = sigma_max ** (1 / rho)
273
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
274
+ return sigmas
275
+
276
+ def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
277
+ """Implementation closely follows k-diffusion.
278
+
279
+ https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
280
+ """
281
+ sigma_min = sigma_min or self.config.sigma_min
282
+ sigma_max = sigma_max or self.config.sigma_max
283
+ sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
284
+ return sigmas
285
+
286
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
287
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
288
+ if schedule_timesteps is None:
289
+ schedule_timesteps = self.timesteps
290
+
291
+ indices = (schedule_timesteps == timestep).nonzero()
292
+
293
+ # The sigma index that is taken for the **very** first `step`
294
+ # is always the second index (or the last index if there is only 1)
295
+ # This way we can ensure we don't accidentally skip a sigma in
296
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
297
+ pos = 1 if len(indices) > 1 else 0
298
+
299
+ return indices[pos].item()
300
+
301
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
302
+ def _init_step_index(self, timestep):
303
+ if self.begin_index is None:
304
+ if isinstance(timestep, torch.Tensor):
305
+ timestep = timestep.to(self.timesteps.device)
306
+ self._step_index = self.index_for_timestep(timestep)
307
+ else:
308
+ self._step_index = self._begin_index
309
+
310
    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        pred_original_sample: Optional[torch.Tensor] = None,
    ) -> Union[EDMEulerSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
            s_tmin (`float`):
            s_tmax (`float`):
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or tuple.
            pred_original_sample (`torch.Tensor`, *optional*):
                A pre-computed denoised sample; when provided, output preconditioning of `model_output` is skipped
                and this tensor is used directly as `x_0`.

        Returns:
            [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """

        # Integer timesteps are almost always a caller bug (passing enumerate() indices).
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EDMEulerScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]

        # Stochastic churn (Karras et al. 2022, Alg. 2): gamma > 0 only when sigma lies in
        # [s_tmin, s_tmax]; it temporarily increases the noise level to sigma_hat.
        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            noise = randn_tensor(
                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
            )
            eps = noise * s_noise
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if pred_original_sample is None:
            pred_original_sample = self.precondition_outputs(sample, model_output, sigma_hat)

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma_hat

        dt = self.sigmas[self.step_index + 1] - sigma_hat

        # Explicit Euler step along the probability-flow ODE.
        prev_sample = sample + derivative * dt

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (
                prev_sample,
                pred_original_sample,
            )

        return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
408
+
409
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply the forward process for the given timesteps: `noisy = original + noise * sigma(t)`.

        Args:
            original_samples (`torch.Tensor`): The clean samples.
            noise (`torch.Tensor`): Gaussian noise with the same shape as `original_samples`.
            timesteps (`torch.Tensor`): One timestep per batch element, used to look up sigmas.

        Returns:
            `torch.Tensor`: The noised samples.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        # Broadcast sigma to the sample rank (one trailing singleton dim per spatial axis).
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples
442
+
443
+ def _get_conditioning_c_in(self, sigma):
444
+ c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
445
+ return c_in
446
+
447
+ def __len__(self):
448
+ return self.config.num_train_timesteps
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_euler_ancestral_discrete.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from ..configuration_utils import ConfigMixin, register_to_config
23
+ from ..utils import BaseOutput, logging
24
+ from ..utils.torch_utils import randn_tensor
25
+ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
26
+
27
+
28
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
+
30
+
31
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete
class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # Sample for the previous timestep (next model input in the denoising loop).
    prev_sample: torch.Tensor
    # Predicted x_0; None when the step did not compute/return it.
    pred_original_sample: Optional[torch.Tensor] = None
48
+
49
+
50
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.


    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":
        alpha_bar_fn = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
    elif alpha_transform_type == "exp":
        alpha_bar_fn = lambda t: math.exp(t * -12.0)
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), clipped at max_beta.
    steps = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar_fn((i + 1) / steps) / alpha_bar_fn(i / steps), max_beta)
        for i in range(steps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
93
+
94
+
95
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)


    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # Work in sqrt(alpha_bar) space, where the algorithm is a simple affine transform.
    alphas_bar_sqrt = torch.cumprod(1.0 - betas, dim=0).sqrt()

    first = alphas_bar_sqrt[0].clone()
    last = alphas_bar_sqrt[-1].clone()

    # Shift so the final sqrt(alpha_bar) is exactly zero, then rescale so the first keeps its value.
    alphas_bar_sqrt = (alphas_bar_sqrt - last) * (first / (first - last))

    # Back out betas: revert the sqrt, then the cumulative product.
    alphas_bar = alphas_bar_sqrt**2
    alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1 - alphas
130
+
131
+
132
class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Ancestral sampling with Euler method steps.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
    ):
        """Build the training beta/sigma schedules and default (training-resolution) timesteps."""
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr:
            # Close to 0 without being 0 so first sigma is not inf
            # FP16 smallest positive subnormal works well here
            self.alphas_cumprod[-1] = 2**-24

        # sigma(t) = sqrt((1 - alpha_bar) / alpha_bar); reversed so sigmas decrease, with a final 0.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.is_scale_input_called = False

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        # "linspace"/"trailing" schedules start at the max sigma itself; otherwise the
        # variance-exploding form sqrt(sigma_max^2 + 1) is used.
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        self.is_scale_input_called = True
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
                ::-1
            ].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # Interpolate training sigmas at the chosen (possibly fractional) timesteps, appending 0.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)

        self.timesteps = torch.from_numpy(timesteps).to(device=device)
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        # Resolve the step index lazily on first use: from begin_index when the pipeline set
        # one, otherwise by matching the timestep against the schedule.
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`,
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.

        """

        # Integer timesteps are almost always a caller bug (passing enumerate() indices).
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # Ancestral split of the noise-level decrement: sigma_to^2 = sigma_down^2 + sigma_up^2,
        # with sigma_up the freshly injected noise scale.
        sigma_from = self.sigmas[self.step_index]
        sigma_to = self.sigmas[self.step_index + 1]
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma

        dt = sigma_down - sigma

        prev_sample = sample + derivative * dt

        device = model_output.device
        noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)

        prev_sample = prev_sample + noise * sigma_up

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (
                prev_sample,
                pred_original_sample,
            )

        return EulerAncestralDiscreteSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample
        )

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the forward process: `noisy = original + noise * sigma(t)` per batch element."""
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        # Broadcast sigma to the sample rank (one trailing singleton dim per spatial axis).
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        # The scheduler's length is the number of training timesteps.
        return self.config.num_train_timesteps
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_euler_discrete.py ADDED
@@ -0,0 +1,757 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from ..configuration_utils import ConfigMixin, register_to_config
23
+ from ..utils import BaseOutput, is_scipy_available, logging
24
+ from ..utils.torch_utils import randn_tensor
25
+ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
26
+
27
+
28
+ if is_scipy_available():
29
+ import scipy.stats
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ @dataclass
35
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
36
+ class EulerDiscreteSchedulerOutput(BaseOutput):
37
+ """
38
+ Output class for the scheduler's `step` function output.
39
+
40
+ Args:
41
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
42
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
43
+ denoising loop.
44
+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
45
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
46
+ `pred_original_sample` can be used to preview progress or for guidance.
47
+ """
48
+
49
+ prev_sample: torch.Tensor
50
+ pred_original_sample: Optional[torch.Tensor] = None
51
+
52
+
53
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
54
+ def betas_for_alpha_bar(
55
+ num_diffusion_timesteps,
56
+ max_beta=0.999,
57
+ alpha_transform_type="cosine",
58
+ ):
59
+ """
60
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
61
+ (1-beta) over time from t = [0,1].
62
+
63
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
64
+ to that part of the diffusion process.
65
+
66
+
67
+ Args:
68
+ num_diffusion_timesteps (`int`): the number of betas to produce.
69
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
70
+ prevent singularities.
71
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
72
+ Choose from `cosine` or `exp`
73
+
74
+ Returns:
75
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
76
+ """
77
+ if alpha_transform_type == "cosine":
78
+
79
+ def alpha_bar_fn(t):
80
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
81
+
82
+ elif alpha_transform_type == "exp":
83
+
84
+ def alpha_bar_fn(t):
85
+ return math.exp(t * -12.0)
86
+
87
+ else:
88
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
89
+
90
+ betas = []
91
+ for i in range(num_diffusion_timesteps):
92
+ t1 = i / num_diffusion_timesteps
93
+ t2 = (i + 1) / num_diffusion_timesteps
94
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
95
+ return torch.tensor(betas, dtype=torch.float32)
96
+
97
+
98
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
99
+ def rescale_zero_terminal_snr(betas):
100
+ """
101
+ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
102
+
103
+
104
+ Args:
105
+ betas (`torch.Tensor`):
106
+ the betas that the scheduler is being initialized with.
107
+
108
+ Returns:
109
+ `torch.Tensor`: rescaled betas with zero terminal SNR
110
+ """
111
+ # Convert betas to alphas_bar_sqrt
112
+ alphas = 1.0 - betas
113
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
114
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
115
+
116
+ # Store old values.
117
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
118
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
119
+
120
+ # Shift so the last timestep is zero.
121
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
122
+
123
+ # Scale so the first timestep is back to the old value.
124
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
125
+
126
+ # Convert alphas_bar_sqrt to betas
127
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
128
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
129
+ alphas = torch.cat([alphas_bar[0:1], alphas])
130
+ betas = 1 - alphas
131
+
132
+ return betas
133
+
134
+
135
+ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
136
+ """
137
+ Euler scheduler.
138
+
139
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
140
+ methods the library implements for all schedulers such as loading and saving.
141
+
142
+ Args:
143
+ num_train_timesteps (`int`, defaults to 1000):
144
+ The number of diffusion steps to train the model.
145
+ beta_start (`float`, defaults to 0.0001):
146
+ The starting `beta` value of inference.
147
+ beta_end (`float`, defaults to 0.02):
148
+ The final `beta` value.
149
+ beta_schedule (`str`, defaults to `"linear"`):
150
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
151
+ `linear` or `scaled_linear`.
152
+ trained_betas (`np.ndarray`, *optional*):
153
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
154
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
155
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
156
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
157
+ Video](https://imagen.research.google/video/paper.pdf) paper).
158
+ interpolation_type(`str`, defaults to `"linear"`, *optional*):
159
+ The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be on of
160
+ `"linear"` or `"log_linear"`.
161
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
162
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
163
+ the sigmas are determined according to a sequence of noise levels {σi}.
164
+ use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
165
+ Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
166
+ use_beta_sigmas (`bool`, *optional*, defaults to `False`):
167
+ Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
168
+ Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
169
+ timestep_spacing (`str`, defaults to `"linspace"`):
170
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
171
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
172
+ steps_offset (`int`, defaults to 0):
173
+ An offset added to the inference steps, as required by some model families.
174
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
175
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
176
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
177
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
178
+ final_sigmas_type (`str`, defaults to `"zero"`):
179
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
180
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
181
+ """
182
+
183
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
184
+ order = 1
185
+
186
+ @register_to_config
187
+ def __init__(
188
+ self,
189
+ num_train_timesteps: int = 1000,
190
+ beta_start: float = 0.0001,
191
+ beta_end: float = 0.02,
192
+ beta_schedule: str = "linear",
193
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
194
+ prediction_type: str = "epsilon",
195
+ interpolation_type: str = "linear",
196
+ use_karras_sigmas: Optional[bool] = False,
197
+ use_exponential_sigmas: Optional[bool] = False,
198
+ use_beta_sigmas: Optional[bool] = False,
199
+ sigma_min: Optional[float] = None,
200
+ sigma_max: Optional[float] = None,
201
+ timestep_spacing: str = "linspace",
202
+ timestep_type: str = "discrete", # can be "discrete" or "continuous"
203
+ steps_offset: int = 0,
204
+ rescale_betas_zero_snr: bool = False,
205
+ final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
206
+ ):
207
+ if self.config.use_beta_sigmas and not is_scipy_available():
208
+ raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
209
+ if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
210
+ raise ValueError(
211
+ "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
212
+ )
213
+ if trained_betas is not None:
214
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
215
+ elif beta_schedule == "linear":
216
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
217
+ elif beta_schedule == "scaled_linear":
218
+ # this schedule is very specific to the latent diffusion model.
219
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
220
+ elif beta_schedule == "squaredcos_cap_v2":
221
+ # Glide cosine schedule
222
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
223
+ else:
224
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
225
+
226
+ if rescale_betas_zero_snr:
227
+ self.betas = rescale_zero_terminal_snr(self.betas)
228
+
229
+ self.alphas = 1.0 - self.betas
230
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
231
+
232
+ if rescale_betas_zero_snr:
233
+ # Close to 0 without being 0 so first sigma is not inf
234
+ # FP16 smallest positive subnormal works well here
235
+ self.alphas_cumprod[-1] = 2**-24
236
+
237
+ sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0)
238
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
239
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
240
+
241
+ # setable values
242
+ self.num_inference_steps = None
243
+
244
+ # TODO: Support the full EDM scalings for all prediction types and timestep types
245
+ if timestep_type == "continuous" and prediction_type == "v_prediction":
246
+ self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
247
+ else:
248
+ self.timesteps = timesteps
249
+
250
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
251
+
252
+ self.is_scale_input_called = False
253
+ self.use_karras_sigmas = use_karras_sigmas
254
+ self.use_exponential_sigmas = use_exponential_sigmas
255
+ self.use_beta_sigmas = use_beta_sigmas
256
+
257
+ self._step_index = None
258
+ self._begin_index = None
259
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
260
+
261
+ @property
262
+ def init_noise_sigma(self):
263
+ # standard deviation of the initial noise distribution
264
+ max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
265
+ if self.config.timestep_spacing in ["linspace", "trailing"]:
266
+ return max_sigma
267
+
268
+ return (max_sigma**2 + 1) ** 0.5
269
+
270
+ @property
271
+ def step_index(self):
272
+ """
273
+ The index counter for current timestep. It will increase 1 after each scheduler step.
274
+ """
275
+ return self._step_index
276
+
277
+ @property
278
+ def begin_index(self):
279
+ """
280
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
281
+ """
282
+ return self._begin_index
283
+
284
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
285
+ def set_begin_index(self, begin_index: int = 0):
286
+ """
287
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
288
+
289
+ Args:
290
+ begin_index (`int`):
291
+ The begin index for the scheduler.
292
+ """
293
+ self._begin_index = begin_index
294
+
295
+ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
296
+ """
297
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
298
+ current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
299
+
300
+ Args:
301
+ sample (`torch.Tensor`):
302
+ The input sample.
303
+ timestep (`int`, *optional*):
304
+ The current timestep in the diffusion chain.
305
+
306
+ Returns:
307
+ `torch.Tensor`:
308
+ A scaled input sample.
309
+ """
310
+ if self.step_index is None:
311
+ self._init_step_index(timestep)
312
+
313
+ sigma = self.sigmas[self.step_index]
314
+ sample = sample / ((sigma**2 + 1) ** 0.5)
315
+
316
+ self.is_scale_input_called = True
317
+ return sample
318
+
319
+ def set_timesteps(
320
+ self,
321
+ num_inference_steps: int = None,
322
+ device: Union[str, torch.device] = None,
323
+ timesteps: Optional[List[int]] = None,
324
+ sigmas: Optional[List[float]] = None,
325
+ ):
326
+ """
327
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
328
+
329
+ Args:
330
+ num_inference_steps (`int`):
331
+ The number of diffusion steps used when generating samples with a pre-trained model.
332
+ device (`str` or `torch.device`, *optional*):
333
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
334
+ timesteps (`List[int]`, *optional*):
335
+ Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
336
+ based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
337
+ must be `None`, and `timestep_spacing` attribute will be ignored.
338
+ sigmas (`List[float]`, *optional*):
339
+ Custom sigmas used to support arbitrary timesteps schedule schedule. If `None`, timesteps and sigmas
340
+ will be generated based on the relevant scheduler attributes. If `sigmas` is passed,
341
+ `num_inference_steps` and `timesteps` must be `None`, and the timesteps will be generated based on the
342
+ custom sigmas schedule.
343
+ """
344
+
345
+ if timesteps is not None and sigmas is not None:
346
+ raise ValueError("Only one of `timesteps` or `sigmas` should be set.")
347
+ if num_inference_steps is None and timesteps is None and sigmas is None:
348
+ raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps` or `sigmas.")
349
+ if num_inference_steps is not None and (timesteps is not None or sigmas is not None):
350
+ raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.")
351
+ if timesteps is not None and self.config.use_karras_sigmas:
352
+ raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
353
+ if timesteps is not None and self.config.use_exponential_sigmas:
354
+ raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
355
+ if timesteps is not None and self.config.use_beta_sigmas:
356
+ raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
357
+ if (
358
+ timesteps is not None
359
+ and self.config.timestep_type == "continuous"
360
+ and self.config.prediction_type == "v_prediction"
361
+ ):
362
+ raise ValueError(
363
+ "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
364
+ )
365
+
366
+ if num_inference_steps is None:
367
+ num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1
368
+ self.num_inference_steps = num_inference_steps
369
+
370
+ if sigmas is not None:
371
+ log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
372
+ sigmas = np.array(sigmas).astype(np.float32)
373
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])
374
+
375
+ else:
376
+ if timesteps is not None:
377
+ timesteps = np.array(timesteps).astype(np.float32)
378
+ else:
379
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
380
+ if self.config.timestep_spacing == "linspace":
381
+ timesteps = np.linspace(
382
+ 0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
383
+ )[::-1].copy()
384
+ elif self.config.timestep_spacing == "leading":
385
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
386
+ # creates integer timesteps by multiplying by ratio
387
+ # casting to int to avoid issues when num_inference_step is power of 3
388
+ timesteps = (
389
+ (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
390
+ )
391
+ timesteps += self.config.steps_offset
392
+ elif self.config.timestep_spacing == "trailing":
393
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
394
+ # creates integer timesteps by multiplying by ratio
395
+ # casting to int to avoid issues when num_inference_step is power of 3
396
+ timesteps = (
397
+ (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
398
+ )
399
+ timesteps -= 1
400
+ else:
401
+ raise ValueError(
402
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
403
+ )
404
+
405
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
406
+ log_sigmas = np.log(sigmas)
407
+ if self.config.interpolation_type == "linear":
408
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
409
+ elif self.config.interpolation_type == "log_linear":
410
+ sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
411
+ else:
412
+ raise ValueError(
413
+ f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
414
+ " 'linear' or 'log_linear'"
415
+ )
416
+
417
+ if self.config.use_karras_sigmas:
418
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
419
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
420
+
421
+ elif self.config.use_exponential_sigmas:
422
+ sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
423
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
424
+
425
+ elif self.config.use_beta_sigmas:
426
+ sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
427
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
428
+
429
+ if self.config.final_sigmas_type == "sigma_min":
430
+ sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
431
+ elif self.config.final_sigmas_type == "zero":
432
+ sigma_last = 0
433
+ else:
434
+ raise ValueError(
435
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
436
+ )
437
+
438
+ sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
439
+
440
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
441
+
442
+ # TODO: Support the full EDM scalings for all prediction types and timestep types
443
+ if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
444
+ self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device)
445
+ else:
446
+ self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
447
+
448
+ self._step_index = None
449
+ self._begin_index = None
450
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
451
+
452
+ def _sigma_to_t(self, sigma, log_sigmas):
453
+ # get log sigma
454
+ log_sigma = np.log(np.maximum(sigma, 1e-10))
455
+
456
+ # get distribution
457
+ dists = log_sigma - log_sigmas[:, np.newaxis]
458
+
459
+ # get sigmas range
460
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
461
+ high_idx = low_idx + 1
462
+
463
+ low = log_sigmas[low_idx]
464
+ high = log_sigmas[high_idx]
465
+
466
+ # interpolate sigmas
467
+ w = (low - log_sigma) / (low - high)
468
+ w = np.clip(w, 0, 1)
469
+
470
+ # transform interpolation to time range
471
+ t = (1 - w) * low_idx + w * high_idx
472
+ t = t.reshape(sigma.shape)
473
+ return t
474
+
475
+ # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
476
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
477
+ """Constructs the noise schedule of Karras et al. (2022)."""
478
+
479
+ # Hack to make sure that other schedulers which copy this function don't break
480
+ # TODO: Add this logic to the other schedulers
481
+ if hasattr(self.config, "sigma_min"):
482
+ sigma_min = self.config.sigma_min
483
+ else:
484
+ sigma_min = None
485
+
486
+ if hasattr(self.config, "sigma_max"):
487
+ sigma_max = self.config.sigma_max
488
+ else:
489
+ sigma_max = None
490
+
491
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
492
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
493
+
494
+ rho = 7.0 # 7.0 is the value used in the paper
495
+ ramp = np.linspace(0, 1, num_inference_steps)
496
+ min_inv_rho = sigma_min ** (1 / rho)
497
+ max_inv_rho = sigma_max ** (1 / rho)
498
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
499
+ return sigmas
500
+
501
+ # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L26
502
+ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
503
+ """Constructs an exponential noise schedule."""
504
+
505
+ # Hack to make sure that other schedulers which copy this function don't break
506
+ # TODO: Add this logic to the other schedulers
507
+ if hasattr(self.config, "sigma_min"):
508
+ sigma_min = self.config.sigma_min
509
+ else:
510
+ sigma_min = None
511
+
512
+ if hasattr(self.config, "sigma_max"):
513
+ sigma_max = self.config.sigma_max
514
+ else:
515
+ sigma_max = None
516
+
517
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
518
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
519
+
520
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
521
+ return sigmas
522
+
523
+ def _convert_to_beta(
524
+ self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
525
+ ) -> torch.Tensor:
526
+ """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
527
+
528
+ # Hack to make sure that other schedulers which copy this function don't break
529
+ # TODO: Add this logic to the other schedulers
530
+ if hasattr(self.config, "sigma_min"):
531
+ sigma_min = self.config.sigma_min
532
+ else:
533
+ sigma_min = None
534
+
535
+ if hasattr(self.config, "sigma_max"):
536
+ sigma_max = self.config.sigma_max
537
+ else:
538
+ sigma_max = None
539
+
540
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
541
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
542
+
543
+ sigmas = np.array(
544
+ [
545
+ sigma_min + (ppf * (sigma_max - sigma_min))
546
+ for ppf in [
547
+ scipy.stats.beta.ppf(timestep, alpha, beta)
548
+ for timestep in 1 - np.linspace(0, 1, num_inference_steps)
549
+ ]
550
+ ]
551
+ )
552
+ return sigmas
553
+
554
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
555
+ if schedule_timesteps is None:
556
+ schedule_timesteps = self.timesteps
557
+
558
+ indices = (schedule_timesteps == timestep).nonzero()
559
+
560
+ # The sigma index that is taken for the **very** first `step`
561
+ # is always the second index (or the last index if there is only 1)
562
+ # This way we can ensure we don't accidentally skip a sigma in
563
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
564
+ pos = 1 if len(indices) > 1 else 0
565
+
566
+ return indices[pos].item()
567
+
568
+ def _init_step_index(self, timestep):
569
+ if self.begin_index is None:
570
+ if isinstance(timestep, torch.Tensor):
571
+ timestep = timestep.to(self.timesteps.device)
572
+ self._step_index = self.index_for_timestep(timestep)
573
+ else:
574
+ self._step_index = self._begin_index
575
+
576
+ def step(
577
+ self,
578
+ model_output: torch.Tensor,
579
+ timestep: Union[float, torch.Tensor],
580
+ sample: torch.Tensor,
581
+ s_churn: float = 0.0,
582
+ s_tmin: float = 0.0,
583
+ s_tmax: float = float("inf"),
584
+ s_noise: float = 1.0,
585
+ generator: Optional[torch.Generator] = None,
586
+ return_dict: bool = True,
587
+ ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
588
+ """
589
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
590
+ process from the learned model outputs (most often the predicted noise).
591
+
592
+ Args:
593
+ model_output (`torch.Tensor`):
594
+ The direct output from learned diffusion model.
595
+ timestep (`float`):
596
+ The current discrete timestep in the diffusion chain.
597
+ sample (`torch.Tensor`):
598
+ A current instance of a sample created by the diffusion process.
599
+ s_churn (`float`):
600
+ s_tmin (`float`):
601
+ s_tmax (`float`):
602
+ s_noise (`float`, defaults to 1.0):
603
+ Scaling factor for noise added to the sample.
604
+ generator (`torch.Generator`, *optional*):
605
+ A random number generator.
606
+ return_dict (`bool`):
607
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
608
+ tuple.
609
+
610
+ Returns:
611
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
612
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
613
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
614
+ """
615
+
616
+ if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
617
+ raise ValueError(
618
+ (
619
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
620
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
621
+ " one of the `scheduler.timesteps` as a timestep."
622
+ ),
623
+ )
624
+
625
+ if not self.is_scale_input_called:
626
+ logger.warning(
627
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
628
+ "See `StableDiffusionPipeline` for a usage example."
629
+ )
630
+
631
+ if self.step_index is None:
632
+ self._init_step_index(timestep)
633
+
634
+ # Upcast to avoid precision issues when computing prev_sample
635
+ sample = sample.to(torch.float32)
636
+
637
+ sigma = self.sigmas[self.step_index]
638
+
639
+ gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
640
+
641
+ sigma_hat = sigma * (gamma + 1)
642
+
643
+ if gamma > 0:
644
+ noise = randn_tensor(
645
+ model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
646
+ )
647
+ eps = noise * s_noise
648
+ sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
649
+
650
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
651
+ # NOTE: "original_sample" should not be an expected prediction_type but is left in for
652
+ # backwards compatibility
653
+ if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
654
+ pred_original_sample = model_output
655
+ elif self.config.prediction_type == "epsilon":
656
+ pred_original_sample = sample - sigma_hat * model_output
657
+ elif self.config.prediction_type == "v_prediction":
658
+ # denoised = model_output * c_out + input * c_skip
659
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
660
+ else:
661
+ raise ValueError(
662
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
663
+ )
664
+
665
+ # 2. Convert to an ODE derivative
666
+ derivative = (sample - pred_original_sample) / sigma_hat
667
+
668
+ dt = self.sigmas[self.step_index + 1] - sigma_hat
669
+
670
+ prev_sample = sample + derivative * dt
671
+
672
+ # Cast sample back to model compatible dtype
673
+ prev_sample = prev_sample.to(model_output.dtype)
674
+
675
+ # upon completion increase step index by one
676
+ self._step_index += 1
677
+
678
+ if not return_dict:
679
+ return (
680
+ prev_sample,
681
+ pred_original_sample,
682
+ )
683
+
684
+ return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
685
+
686
+ def add_noise(
687
+ self,
688
+ original_samples: torch.Tensor,
689
+ noise: torch.Tensor,
690
+ timesteps: torch.Tensor,
691
+ ) -> torch.Tensor:
692
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
693
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
694
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
695
+ # mps does not support float64
696
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
697
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
698
+ else:
699
+ schedule_timesteps = self.timesteps.to(original_samples.device)
700
+ timesteps = timesteps.to(original_samples.device)
701
+
702
+ # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
703
+ if self.begin_index is None:
704
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
705
+ elif self.step_index is not None:
706
+ # add_noise is called after first denoising step (for inpainting)
707
+ step_indices = [self.step_index] * timesteps.shape[0]
708
+ else:
709
+ # add noise is called before first denoising step to create initial latent(img2img)
710
+ step_indices = [self.begin_index] * timesteps.shape[0]
711
+
712
+ sigma = sigmas[step_indices].flatten()
713
+ while len(sigma.shape) < len(original_samples.shape):
714
+ sigma = sigma.unsqueeze(-1)
715
+
716
+ noisy_samples = original_samples + noise * sigma
717
+ return noisy_samples
718
+
719
+ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
720
+ if (
721
+ isinstance(timesteps, int)
722
+ or isinstance(timesteps, torch.IntTensor)
723
+ or isinstance(timesteps, torch.LongTensor)
724
+ ):
725
+ raise ValueError(
726
+ (
727
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
728
+ " `EulerDiscreteScheduler.get_velocity()` is not supported. Make sure to pass"
729
+ " one of the `scheduler.timesteps` as a timestep."
730
+ ),
731
+ )
732
+
733
+ if sample.device.type == "mps" and torch.is_floating_point(timesteps):
734
+ # mps does not support float64
735
+ schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
736
+ timesteps = timesteps.to(sample.device, dtype=torch.float32)
737
+ else:
738
+ schedule_timesteps = self.timesteps.to(sample.device)
739
+ timesteps = timesteps.to(sample.device)
740
+
741
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
742
+ alphas_cumprod = self.alphas_cumprod.to(sample)
743
+ sqrt_alpha_prod = alphas_cumprod[step_indices] ** 0.5
744
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
745
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
746
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
747
+
748
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[step_indices]) ** 0.5
749
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
750
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
751
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
752
+
753
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
754
+ return velocity
755
+
756
+ def __len__(self):
757
+ return self.config.num_train_timesteps
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_euler_discrete_flax.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import flax
19
+ import jax.numpy as jnp
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from .scheduling_utils_flax import (
23
+ CommonSchedulerState,
24
+ FlaxKarrasDiffusionSchedulers,
25
+ FlaxSchedulerMixin,
26
+ FlaxSchedulerOutput,
27
+ broadcast_to_shape_from_left,
28
+ )
29
+
30
+
31
@flax.struct.dataclass
class EulerDiscreteSchedulerState:
    """Immutable (pytree) state carried through JAX-traced sampling for `FlaxEulerDiscreteScheduler`."""

    # Shared beta/alpha schedule state common to all flax schedulers.
    common: CommonSchedulerState

    # setable values
    # Standard deviation of the initial noise distribution (scales the first latents).
    init_noise_sigma: jnp.ndarray
    # Discrete timesteps in descending order.
    timesteps: jnp.ndarray
    # Per-timestep noise levels; has one extra trailing 0.0 terminal entry.
    sigmas: jnp.ndarray
    # Set by `set_timesteps`; None until then.
    num_inference_steps: Optional[int] = None

    @classmethod
    def create(
        cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
    ):
        """Convenience constructor leaving `num_inference_steps` unset."""
        return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
46
+
47
+
48
@dataclass
class FlaxEulerDiscreteSchedulerOutput(FlaxSchedulerOutput):
    """Output of `FlaxEulerDiscreteScheduler.step`: the base `prev_sample` plus the updated state."""

    # Updated scheduler state to thread into the next `step` call (flax schedulers are stateless objects).
    state: EulerDiscreteSchedulerState
51
+
52
+
53
class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
    """
    Euler scheduler (Algorithm 2) from Karras et al. (2022) https://huggingface.co/papers/2206.00364. Based on the
    original k-diffusion implementation by Katherine Crowson:
    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51


    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`jnp.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
            https://imagen.research.google/video/paper.pdf)
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            the `dtype` used for params and computation.
    """

    # Names of schedulers this one can be swapped with via `from_config`.
    _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]

    # Computation dtype for sigmas/timesteps (set in `__init__`).
    dtype: jnp.dtype

    @property
    def has_state(self):
        # Flax schedulers keep their mutable values in an external state object
        # (see `create_state`), not on the instance.
        return True
89
+
90
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[jnp.ndarray] = None,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        dtype: jnp.dtype = jnp.float32,
    ):
        # All schedule parameters are captured into `self.config` by the
        # `@register_to_config` decorator; the only instance attribute needed
        # here is the computation dtype.
        self.dtype = dtype
103
+
104
    def create_state(self, common: Optional[CommonSchedulerState] = None) -> EulerDiscreteSchedulerState:
        """
        Build the initial scheduler state over the full training grid.

        Args:
            common (`CommonSchedulerState`, *optional*):
                Pre-built shared beta/alpha state; created from the config when not given.

        Returns:
            `EulerDiscreteSchedulerState`: state with descending timesteps, matching sigmas
            (plus a terminal 0.0 sigma), and the initial noise standard deviation.
        """
        if common is None:
            common = CommonSchedulerState.create(self)

        # Full training grid, highest timestep first.
        timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
        # Karras-style sigma: sigma_t = sqrt((1 - abar_t) / abar_t).
        sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5
        sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas)
        # Append a terminal sigma of 0 so `step` can always read sigmas[i + 1].
        sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)])

        # standard deviation of the initial noise distribution
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            init_noise_sigma = sigmas.max()
        else:
            init_noise_sigma = (sigmas.max() ** 2 + 1) ** 0.5

        return EulerDiscreteSchedulerState.create(
            common=common,
            init_noise_sigma=init_noise_sigma,
            timesteps=timesteps,
            sigmas=sigmas,
        )
125
+
126
    def scale_model_input(self, state: EulerDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray:
        """
        Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

        Args:
            state (`EulerDiscreteSchedulerState`):
                the `FlaxEulerDiscreteScheduler` state data class instance.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            timestep (`int`):
                current discrete timestep in the diffusion chain.

        Returns:
            `jnp.ndarray`: scaled input sample
        """
        # JIT-friendly lookup of the timestep's position in the schedule
        # (`size=1` makes the result shape static for tracing).
        (step_index,) = jnp.where(state.timesteps == timestep, size=1)
        step_index = step_index[0]

        sigma = state.sigmas[step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample
147
+
148
    def set_timesteps(
        self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
    ) -> EulerDiscreteSchedulerState:
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`EulerDiscreteSchedulerState`):
                the `FlaxEulerDiscreteScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            shape (`Tuple`, *optional*):
                unused; kept for interface compatibility with other flax schedulers.

        Returns:
            `EulerDiscreteSchedulerState`: a new state with inference timesteps, sigmas, and
            `init_noise_sigma` replaced.
        """

        if self.config.timestep_spacing == "linspace":
            timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // num_inference_steps
            timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
            # "leading" spacing offsets the grid by one so the last timestep is not 0.
            timesteps += 1
        else:
            # NOTE(review): `create_state` also recognizes "trailing", but it is not
            # implemented here — confirm whether that spacing should be supported.
            raise ValueError(
                f"timestep_spacing must be one of ['linspace', 'leading'], got {self.config.timestep_spacing}"
            )

        # Interpolate training sigmas onto the (possibly fractional) inference timesteps.
        sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5
        sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas)
        # Terminal 0.0 sigma so `step` can always read sigmas[i + 1].
        sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)])

        # standard deviation of the initial noise distribution
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            init_noise_sigma = sigmas.max()
        else:
            init_noise_sigma = (sigmas.max() ** 2 + 1) ** 0.5

        return state.replace(
            timesteps=timesteps,
            sigmas=sigmas,
            num_inference_steps=num_inference_steps,
            init_noise_sigma=init_noise_sigma,
        )
188
+
189
    def step(
        self,
        state: EulerDiscreteSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        return_dict: bool = True,
    ) -> Union[FlaxEulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            state (`EulerDiscreteSchedulerState`):
                the `FlaxEulerDiscreteScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than FlaxEulerDiscreteScheduler class

        Returns:
            [`FlaxEulerDiscreteSchedulerOutput`] or `tuple`: [`FlaxEulerDiscreteSchedulerOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: If `set_timesteps` has not been called on the state.
        """
        if state.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # JIT-friendly lookup of the current position in the schedule.
        (step_index,) = jnp.where(state.timesteps == timestep, size=1)
        step_index = step_index[0]

        sigma = state.sigmas[step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma

        # dt = sigma_down - sigma
        dt = state.sigmas[step_index + 1] - sigma

        # Euler step along the probability-flow ODE.
        prev_sample = sample + derivative * dt

        if not return_dict:
            return (prev_sample, state)

        return FlaxEulerDiscreteSchedulerOutput(prev_sample=prev_sample, state=state)
249
+
250
    def add_noise(
        self,
        state: EulerDiscreteSchedulerState,
        original_samples: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        Forward diffusion: `x_t = x_0 + sigma_t * noise`.

        Args:
            state (`EulerDiscreteSchedulerState`): scheduler state holding the sigma schedule.
            original_samples (`jnp.ndarray`): clean samples `x_0`.
            noise (`jnp.ndarray`): Gaussian noise with the same shape as `original_samples`.
            timesteps (`jnp.ndarray`): integer indices into `state.sigmas` (one per batch element).

        Returns:
            `jnp.ndarray`: the noised samples.
        """
        # Look up each sample's sigma and broadcast it over the trailing dims.
        sigma = state.sigmas[timesteps].flatten()
        sigma = broadcast_to_shape_from_left(sigma, noise.shape)

        noisy_samples = original_samples + noise * sigma

        return noisy_samples
263
+
264
    def __len__(self) -> int:
        # The scheduler's "length" is the size of the training timestep grid.
        return self.config.num_train_timesteps
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_flow_match_euler_discrete.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from ..configuration_utils import ConfigMixin, register_to_config
23
+ from ..utils import BaseOutput, is_scipy_available, logging
24
+ from .scheduling_utils import SchedulerMixin
25
+
26
+
27
+ if is_scipy_available():
28
+ import scipy.stats
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor
45
+
46
+
47
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
        use_dynamic_shifting (`bool`, defaults to False):
            Whether to apply timestep shifting on-the-fly based on the image resolution.
        base_shift (`float`, defaults to 0.5):
            Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent
            with desired output.
        max_shift (`float`, defaults to 1.15):
            Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be
            more exaggerated or stylized.
        base_image_seq_len (`int`, defaults to 256):
            The base image sequence length.
        max_image_seq_len (`int`, defaults to 4096):
            The maximum image sequence length.
        invert_sigmas (`bool`, defaults to False):
            Whether to invert the sigmas.
        shift_terminal (`float`, defaults to None):
            The end value of the shifted timestep schedule.
        use_karras_sigmas (`bool`, defaults to False):
            Whether to use Karras sigmas for step sizes in the noise schedule during sampling.
        use_exponential_sigmas (`bool`, defaults to False):
            Whether to use exponential sigmas for step sizes in the noise schedule during sampling.
        use_beta_sigmas (`bool`, defaults to False):
            Whether to use beta sigmas for step sizes in the noise schedule during sampling.
        time_shift_type (`str`, defaults to "exponential"):
            The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
        stochastic_sampling (`bool`, defaults to False):
            Whether to use stochastic sampling.
    """

    # No interchangeable schedulers registered for flow matching.
    _compatibles = []
    # First-order (single-step) solver.
    order = 1
89
+
90
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        use_dynamic_shifting: bool = False,
        base_shift: Optional[float] = 0.5,
        max_shift: Optional[float] = 1.15,
        base_image_seq_len: Optional[int] = 256,
        max_image_seq_len: Optional[int] = 4096,
        invert_sigmas: bool = False,
        shift_terminal: Optional[float] = None,
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        time_shift_type: str = "exponential",
        stochastic_sampling: bool = False,
    ):
        # Config attributes are populated by `@register_to_config` before the body runs,
        # hence the `self.config.*` reads below.
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
        # The three sigma-schedule variants are mutually exclusive.
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )
        if time_shift_type not in {"exponential", "linear"}:
            raise ValueError("`time_shift_type` must either be 'exponential' or 'linear'.")

        # Training grid: timesteps descend from num_train_timesteps to 1.
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        # In flow matching, sigma is the normalized time in (0, 1].
        sigmas = timesteps / num_train_timesteps
        if not use_dynamic_shifting:
            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self._shift = shift

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
135
+
136
    @property
    def shift(self):
        """
        The value used for shifting.
        """
        return self._shift
142
+
143
    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index
149
+
150
    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index
156
+
157
    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index
167
+
168
    def set_shift(self, shift: float):
        """Override the static timestep-shift factor used by `set_timesteps`."""
        self._shift = shift
170
+
171
    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching: `x_t = sigma * noise + (1 - sigma) * x_0`.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.
            noise (`torch.FloatTensor`, *optional*):
                Gaussian noise to blend in; same shape as `sample`.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)

        if sample.device.type == "mps" and torch.is_floating_point(timestep):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
            timestep = timestep.to(sample.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(sample.device)
            timestep = timestep.to(sample.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timestep.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timestep.shape[0]

        # Broadcast each sample's sigma over the trailing dims.
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(sample.shape):
            sigma = sigma.unsqueeze(-1)

        sample = sigma * noise + (1.0 - sigma) * sample

        return sample
218
+
219
    def _sigma_to_t(self, sigma):
        # In flow matching sigma is the normalized time, so t is simply rescaled to the training grid.
        return sigma * self.config.num_train_timesteps
221
+
222
    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        """Dispatch resolution-dependent timestep shifting per `config.time_shift_type`."""
        # `time_shift_type` is validated in `__init__`, so exactly one branch is taken.
        if self.config.time_shift_type == "exponential":
            return self._time_shift_exponential(mu, sigma, t)
        elif self.config.time_shift_type == "linear":
            return self._time_shift_linear(mu, sigma, t)
227
+
228
    def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
        r"""
        Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
        value.

        Reference:
        https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51

        Args:
            t (`torch.Tensor`):
                A tensor of timesteps to be stretched and shifted.

        Returns:
            `torch.Tensor`:
                A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
        """
        # Linearly rescale (1 - t) so that the last entry maps exactly to shift_terminal.
        one_minus_z = 1 - t
        scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
        stretched_t = 1 - (one_minus_z / scale_factor)
        return stretched_t
248
+
249
    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[float] = None,
        timesteps: Optional[List[float]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            sigmas (`List[float]`, *optional*):
                Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
                automatically.
            mu (`float`, *optional*):
                Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep
                shifting.
            timesteps (`List[float]`, *optional*):
                Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed
                automatically.

        Raises:
            ValueError: If `mu` is missing while `use_dynamic_shifting` is enabled, or if the
                provided `sigmas`/`timesteps`/`num_inference_steps` lengths are inconsistent.
        """
        if self.config.use_dynamic_shifting and mu is None:
            raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")

        if sigmas is not None and timesteps is not None:
            if len(sigmas) != len(timesteps):
                raise ValueError("`sigmas` and `timesteps` should have the same length")

        if num_inference_steps is not None:
            if (sigmas is not None and len(sigmas) != num_inference_steps) or (
                timesteps is not None and len(timesteps) != num_inference_steps
            ):
                raise ValueError(
                    "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided"
                )
        else:
            num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps)

        self.num_inference_steps = num_inference_steps

        # 1. Prepare default sigmas
        is_timesteps_provided = timesteps is not None

        if is_timesteps_provided:
            timesteps = np.array(timesteps).astype(np.float32)

        if sigmas is None:
            if timesteps is None:
                timesteps = np.linspace(
                    self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
                )
            sigmas = timesteps / self.config.num_train_timesteps
        else:
            sigmas = np.array(sigmas).astype(np.float32)
            num_inference_steps = len(sigmas)

        # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of
        # "exponential" or "linear" type is applied
        if self.config.use_dynamic_shifting:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)

        # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value
        if self.config.shift_terminal:
            sigmas = self.stretch_shift_to_terminal(sigmas)

        # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules
        if self.config.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        elif self.config.use_exponential_sigmas:
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        elif self.config.use_beta_sigmas:
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)

        # 5. Convert sigmas and timesteps to tensors and move to specified device
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        if not is_timesteps_provided:
            timesteps = sigmas * self.config.num_train_timesteps
        else:
            timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)

        # 6. Append the terminal sigma value.
        # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the
        # `invert_sigmas` flag can be set to `True`. This case is only required in Mochi
        if self.config.invert_sigmas:
            sigmas = 1.0 - sigmas
            timesteps = sigmas * self.config.num_train_timesteps
            sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
        else:
            sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

        self.timesteps = timesteps
        self.sigmas = sigmas
        # Reset the step/begin counters; `step` will re-initialize them lazily.
        self._step_index = None
        self._begin_index = None
350
+
351
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
352
+ if schedule_timesteps is None:
353
+ schedule_timesteps = self.timesteps
354
+
355
+ indices = (schedule_timesteps == timestep).nonzero()
356
+
357
+ # The sigma index that is taken for the **very** first `step`
358
+ # is always the second index (or the last index if there is only 1)
359
+ # This way we can ensure we don't accidentally skip a sigma in
360
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
361
+ pos = 1 if len(indices) > 1 else 0
362
+
363
+ return indices[pos].item()
364
+
365
    def _init_step_index(self, timestep):
        """Initialize `_step_index` from `timestep` (or from `begin_index` when one was set)."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index
372
+
373
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        per_token_timesteps: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
                Unused; kept for signature compatibility with other Euler schedulers.
            s_tmin (`float`):
                Unused; kept for signature compatibility with other Euler schedulers.
            s_tmax (`float`):
                Unused; kept for signature compatibility with other Euler schedulers.
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            per_token_timesteps (`torch.Tensor`, *optional*):
                The timesteps for each token in the sample.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`,
                [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.
        """

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        if per_token_timesteps is not None:
            per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps

            # For each token, find the largest schedule sigma strictly below its
            # current sigma (with a small tolerance) — that is its "next" sigma.
            sigmas = self.sigmas[:, None, None]
            lower_mask = sigmas < per_token_sigmas[None] - 1e-6
            lower_sigmas = lower_mask * sigmas
            lower_sigmas, _ = lower_sigmas.max(dim=0)

            current_sigma = per_token_sigmas[..., None]
            next_sigma = lower_sigmas[..., None]
            dt = current_sigma - next_sigma
        else:
            sigma_idx = self.step_index
            sigma = self.sigmas[sigma_idx]
            sigma_next = self.sigmas[sigma_idx + 1]

            current_sigma = sigma
            next_sigma = sigma_next
            dt = sigma_next - sigma

        if self.config.stochastic_sampling:
            # Re-noise the predicted x0 to the next sigma level instead of taking a deterministic Euler step.
            # NOTE(review): the `generator` argument is not threaded into `randn_like` here,
            # so stochastic sampling is not reproducible via `generator` — confirm upstream intent.
            x0 = sample - current_sigma * model_output
            noise = torch.randn_like(sample)
            prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
        else:
            prev_sample = sample + dt * model_output

        # upon completion increase step index by one
        self._step_index += 1
        if per_token_timesteps is None:
            # Cast sample back to model compatible dtype
            prev_sample = prev_sample.to(model_output.dtype)

        if not return_dict:
            return (prev_sample,)

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
473
+
474
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        # Fall back to the endpoints of the incoming schedule when the config does not pin them.
        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        # Interpolate in sigma^(1/rho) space (Karras et al. 2022, eq. 5), then raise back.
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas
499
+
500
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
        """Constructs an exponential noise schedule."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        # Fall back to the endpoints of the incoming schedule when the config does not pin them.
        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        # Geometric (log-linear) interpolation between sigma_max and sigma_min.
        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
        return sigmas
521
+
522
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(
        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> torch.Tensor:
        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        # Fall back to the endpoints of the incoming schedule when the config does not pin them.
        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        # Map a descending uniform grid through the Beta(alpha, beta) quantile function,
        # then rescale into [sigma_min, sigma_max].
        sigmas = np.array(
            [
                sigma_min + (ppf * (sigma_max - sigma_min))
                for ppf in [
                    scipy.stats.beta.ppf(timestep, alpha, beta)
                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
                ]
            ]
        )
        return sigmas
553
+
554
+ def _time_shift_exponential(self, mu, sigma, t):
555
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
556
+
557
+ def _time_shift_linear(self, mu, sigma, t):
558
+ return mu / (mu + (1 / t - 1) ** sigma)
559
+
560
    def __len__(self) -> int:
        # The scheduler's "length" is the size of the training timestep grid.
        return self.config.num_train_timesteps
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_flow_match_heun_discrete.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..utils import BaseOutput, logging
23
+ from ..utils.torch_utils import randn_tensor
24
+ from .scheduling_utils import SchedulerMixin
25
+
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
@dataclass
class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput):
    """
    Container returned by the scheduler's `step` function.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample `(x_{t-1})` computed for the previous timestep; it should be fed back to the model as the
            input of the next denoising iteration.
    """

    prev_sample: torch.FloatTensor
42
+
43
+
44
class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Heun scheduler for flow-matching models (second-order predictor/corrector).

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
    """

    _compatibles = []
    # Heun's method is second order: each denoising step consumes two model
    # evaluations (an Euler predictor followed by a corrector).
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
    ):
        # Training timesteps laid out in descending order: N, N-1, ..., 1.
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        # In flow matching, sigma is simply the timestep normalised to (0, 1].
        sigmas = timesteps / num_train_timesteps
        # Apply the configured schedule shift (identity when shift == 1).
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching: linear interpolation between sample and noise.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.
            noise (`torch.FloatTensor`, *optional*):
                The noise to mix in. Despite the `Optional` annotation it is required (see note below).

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        # NOTE(review): `noise` defaults to None but is used unconditionally on
        # the next line — callers must always pass it; confirm intended API.
        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        # Inverse of the sigma = t / num_train_timesteps normalisation.
        return sigma * self.config.num_train_timesteps

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        timesteps = np.linspace(
            self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
        )

        sigmas = timesteps / self.config.num_train_timesteps
        sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

        # Every timestep after the first appears twice because each Heun step
        # (after the initial one) consumes two scheduler steps: predictor and
        # corrector are called with the same timestep.
        timesteps = sigmas * self.config.num_train_timesteps
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
        self.timesteps = timesteps.to(device=device)

        # Append a terminal sigma of 0 and interleave interior sigmas the same
        # way the timesteps were interleaved above.
        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

        # empty dt and derivative (resets the solver to first-order state)
        self.prev_derivative = None
        self.dt = None

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        # Map a timestep value back to its position in the (interleaved) schedule.
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        # Resolve the starting step index either from the timestep value or
        # from an explicitly configured begin index.
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    @property
    def state_in_first_order(self):
        # True while waiting for the Euler (predictor) half of a Heun step;
        # `self.dt` is only populated between the two halves.
        return self.dt is None

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
                Amount of stochastic "churn" noise to add (0 disables it).
            s_tmin (`float`):
                Lower sigma bound for applying churn.
            s_tmax (`float`):
                Upper sigma bound for applying churn.
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput`] tuple.

        Returns:
            [`~schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`,
                [`~schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.
        """

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchHeunDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        if self.state_in_first_order:
            sigma = self.sigmas[self.step_index]
            sigma_next = self.sigmas[self.step_index + 1]
        else:
            # 2nd order / Heun's method: step_index already advanced once, so
            # the interval endpoints sit at (step_index - 1, step_index).
            sigma = self.sigmas[self.step_index - 1]
            sigma_next = self.sigmas[self.step_index]

        # Optional stochastic churn (Karras et al. style), active only inside
        # the [s_tmin, s_tmax] sigma band.
        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            noise = randn_tensor(
                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
            )
            eps = noise * s_noise
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        if self.state_in_first_order:
            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
            denoised = sample - model_output * sigma
            # 2. convert to an ODE derivative for 1st order
            derivative = (sample - denoised) / sigma_hat
            # 3. Delta timestep
            dt = sigma_next - sigma_hat

            # store for 2nd order step
            self.prev_derivative = derivative
            self.dt = dt
            self.sample = sample
        else:
            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
            denoised = sample - model_output * sigma_next
            # 2. 2nd order / Heun's method: average predictor and corrector slopes
            derivative = (sample - denoised) / sigma_next
            derivative = 0.5 * (self.prev_derivative + derivative)

            # 3. take prev timestep & sample (stored by the first-order half)
            dt = self.dt
            sample = self.sample

            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.prev_derivative = None
            self.dt = None
            self.sample = None

        prev_sample = sample + derivative * dt
        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        # The scheduler's "length" is its number of training timesteps.
        return self.config.num_train_timesteps
pythonProject/.venv/Lib/site-packages/fsspec/__init__.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import caching
2
+ from ._version import __version__ # noqa: F401
3
+ from .callbacks import Callback
4
+ from .compression import available_compressions
5
+ from .core import get_fs_token_paths, open, open_files, open_local, url_to_fs
6
+ from .exceptions import FSTimeoutError
7
+ from .mapping import FSMap, get_mapper
8
+ from .registry import (
9
+ available_protocols,
10
+ filesystem,
11
+ get_filesystem_class,
12
+ register_implementation,
13
+ registry,
14
+ )
15
+ from .spec import AbstractFileSystem
16
+
17
+ __all__ = [
18
+ "AbstractFileSystem",
19
+ "FSTimeoutError",
20
+ "FSMap",
21
+ "filesystem",
22
+ "register_implementation",
23
+ "get_filesystem_class",
24
+ "get_fs_token_paths",
25
+ "get_mapper",
26
+ "open",
27
+ "open_files",
28
+ "open_local",
29
+ "registry",
30
+ "caching",
31
+ "Callback",
32
+ "available_protocols",
33
+ "available_compressions",
34
+ "url_to_fs",
35
+ ]
36
+
37
+
38
def process_entries():
    """Register filesystem implementations advertised through the
    ``fsspec.specs`` entry-point group, overriding any identically-named
    built-in implementations."""
    try:
        from importlib.metadata import entry_points
    except ImportError:
        return
    if entry_points is None:
        return
    try:
        eps = entry_points()
    except TypeError:  # importlib-metadata < 0.8
        return
    # Python 3.10+ / importlib_metadata >= 3.9.0 expose ``.select``; older
    # versions behave like a mapping of group name -> entries.
    if hasattr(eps, "select"):
        specs = eps.select(group="fsspec.specs")
    else:
        specs = eps.get("fsspec.specs", [])
    seen = set()
    for spec in specs:
        name = spec.name
        if name in seen:
            continue
        seen.add(name)
        register_implementation(
            name,
            spec.value.replace(":", "."),
            errtxt=f"Unable to load filesystem from {spec}",
            # We take our implementations as the ones to overload with if
            # for some reason we encounter some, may be the same, already
            # registered
            clobber=True,
        )
70
+
71
+ process_entries()
pythonProject/.venv/Lib/site-packages/fsspec/_version.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # file generated by setuptools-scm
2
+ # don't change, don't track in version control
3
+
4
+ __all__ = [
5
+ "__version__",
6
+ "__version_tuple__",
7
+ "version",
8
+ "version_tuple",
9
+ "__commit_id__",
10
+ "commit_id",
11
+ ]
12
+
13
+ TYPE_CHECKING = False
14
+ if TYPE_CHECKING:
15
+ from typing import Tuple
16
+ from typing import Union
17
+
18
+ VERSION_TUPLE = Tuple[Union[int, str], ...]
19
+ COMMIT_ID = Union[str, None]
20
+ else:
21
+ VERSION_TUPLE = object
22
+ COMMIT_ID = object
23
+
24
+ version: str
25
+ __version__: str
26
+ __version_tuple__: VERSION_TUPLE
27
+ version_tuple: VERSION_TUPLE
28
+ commit_id: COMMIT_ID
29
+ __commit_id__: COMMIT_ID
30
+
31
+ __version__ = version = '2025.9.0'
32
+ __version_tuple__ = version_tuple = (2025, 9, 0)
33
+
34
+ __commit_id__ = commit_id = None
pythonProject/.venv/Lib/site-packages/fsspec/implementations/arrow.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import errno
2
+ import io
3
+ import os
4
+ import secrets
5
+ import shutil
6
+ from contextlib import suppress
7
+ from functools import cached_property, wraps
8
+ from urllib.parse import parse_qs
9
+
10
+ from fsspec.spec import AbstractFileSystem
11
+ from fsspec.utils import (
12
+ get_package_version_without_import,
13
+ infer_storage_options,
14
+ mirror_from,
15
+ tokenize,
16
+ )
17
+
18
+
19
def wrap_exceptions(func):
    """Decorator translating pyarrow's generic ``OSError`` for missing paths
    into the ``FileNotFoundError`` fsspec callers expect; everything else is
    re-raised untouched."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except OSError as exc:
            message = exc.args[0] if exc.args else None
            if isinstance(message, str) and "does not exist" in message:
                raise FileNotFoundError(errno.ENOENT, message) from exc
            raise

    return wrapper
35
+
36
+
37
+ PYARROW_VERSION = None
38
+
39
+
40
class ArrowFSWrapper(AbstractFileSystem):
    """FSSpec-compatible wrapper of pyarrow.fs.FileSystem.

    Parameters
    ----------
    fs : pyarrow.fs.FileSystem
        The pyarrow filesystem instance to delegate all operations to.
    """

    root_marker = "/"

    def __init__(self, fs, **kwargs):
        global PYARROW_VERSION
        PYARROW_VERSION = get_package_version_without_import("pyarrow")
        self.fs = fs
        super().__init__(**kwargs)

    @property
    def protocol(self):
        # Delegate to the wrapped filesystem's own protocol name.
        return self.fs.type_name

    @cached_property
    def fsid(self):
        return "hdfs_" + tokenize(self.fs.host, self.fs.port)

    @classmethod
    def _strip_protocol(cls, path):
        """Remove the protocol/host part of a URL, returning a bare path."""
        ops = infer_storage_options(path)
        path = ops["path"]
        if path.startswith("//"):
            # special case for "hdfs://path" (without the triple slash)
            path = path[1:]
        return path

    def ls(self, path, detail=False, **kwargs):
        """List the contents of *path*; with ``detail=True`` return dicts."""
        path = self._strip_protocol(path)
        from pyarrow.fs import FileSelector

        entries = [
            self._make_entry(entry)
            for entry in self.fs.get_file_info(FileSelector(path))
        ]
        if detail:
            return entries
        return [entry["name"] for entry in entries]

    def info(self, path, **kwargs):
        """Return the metadata dict for a single path."""
        path = self._strip_protocol(path)
        [info] = self.fs.get_file_info([path])
        return self._make_entry(info)

    def exists(self, path):
        path = self._strip_protocol(path)
        try:
            self.info(path)
        except FileNotFoundError:
            return False
        return True

    def _make_entry(self, info):
        """Convert a pyarrow ``FileInfo`` into an fsspec info dict."""
        from pyarrow.fs import FileType

        if info.type is FileType.Directory:
            kind = "directory"
        elif info.type is FileType.File:
            kind = "file"
        elif info.type is FileType.NotFound:
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), info.path)
        else:
            kind = "other"

        return {
            "name": info.path,
            "size": info.size,
            "type": kind,
            "mtime": info.mtime,
        }

    @wrap_exceptions
    def cp_file(self, path1, path2, **kwargs):
        """Copy a single file by streaming through a temporary name, so a
        failed copy never leaves a partially written destination."""
        path1 = self._strip_protocol(path1).rstrip("/")
        path2 = self._strip_protocol(path2).rstrip("/")

        with self._open(path1, "rb") as lstream:
            tmp_fname = f"{path2}.tmp.{secrets.token_hex(6)}"
            try:
                with self.open(tmp_fname, "wb") as rstream:
                    shutil.copyfileobj(lstream, rstream)
                # Atomically promote the finished temp file to its final name.
                self.fs.move(tmp_fname, path2)
            except BaseException:
                with suppress(FileNotFoundError):
                    self.fs.delete_file(tmp_fname)
                raise

    @wrap_exceptions
    def mv(self, path1, path2, **kwargs):
        path1 = self._strip_protocol(path1).rstrip("/")
        path2 = self._strip_protocol(path2).rstrip("/")
        self.fs.move(path1, path2)

    @wrap_exceptions
    def rm_file(self, path):
        path = self._strip_protocol(path)
        self.fs.delete_file(path)

    @wrap_exceptions
    def rm(self, path, recursive=False, maxdepth=None):
        """Delete *path*; directories require ``recursive=True``."""
        path = self._strip_protocol(path).rstrip("/")
        if self.isdir(path):
            if recursive:
                self.fs.delete_dir(path)
            else:
                # Fix: the message previously said "recursive=False", which
                # contradicted the condition above.
                raise ValueError("Can't delete directories without recursive=True")
        else:
            self.fs.delete_file(path)

    @wrap_exceptions
    def _open(self, path, mode="rb", block_size=None, seekable=True, **kwargs):
        """Open *path*; ``seekable=False`` reads use a plain input stream."""
        if mode == "rb":
            if seekable:
                method = self.fs.open_input_file
            else:
                method = self.fs.open_input_stream
        elif mode == "wb":
            method = self.fs.open_output_stream
        elif mode == "ab":
            method = self.fs.open_append_stream
        else:
            raise ValueError(f"unsupported mode for Arrow filesystem: {mode!r}")

        _kwargs = {}
        if mode != "rb" or not seekable:
            if int(PYARROW_VERSION.split(".")[0]) >= 4:
                # disable compression auto-detection
                _kwargs["compression"] = None
        stream = method(path, **_kwargs)

        return ArrowFile(self, stream, path, mode, block_size, **kwargs)

    @wrap_exceptions
    def mkdir(self, path, create_parents=True, **kwargs):
        path = self._strip_protocol(path)
        if create_parents:
            # NOTE(review): exist_ok is not honoured by makedirs below —
            # pyarrow's recursive create_dir does not fail on existing dirs.
            self.makedirs(path, exist_ok=True)
        else:
            self.fs.create_dir(path, recursive=False)

    @wrap_exceptions
    def makedirs(self, path, exist_ok=False):
        path = self._strip_protocol(path)
        self.fs.create_dir(path, recursive=True)

    @wrap_exceptions
    def rmdir(self, path):
        path = self._strip_protocol(path)
        self.fs.delete_dir(path)

    @wrap_exceptions
    def modified(self, path):
        path = self._strip_protocol(path)
        return self.fs.get_file_info(path).mtime

    def cat_file(self, path, start=None, end=None, **kwargs):
        # A plain input stream cannot seek, so request a seekable handle
        # whenever a non-zero start offset must be honoured.
        kwargs["seekable"] = start not in [None, 0]
        # Fix: forward the requested byte range; previously start/end were
        # discarded (start=None, end=None), so range reads returned the
        # whole file.
        return super().cat_file(path, start=start, end=end, **kwargs)

    def get_file(self, rpath, lpath, **kwargs):
        # Whole-file download never needs to seek.
        kwargs["seekable"] = False
        super().get_file(rpath, lpath, **kwargs)
211
+
212
+
213
@mirror_from(
    "stream",
    [
        "read",
        "seek",
        "tell",
        "write",
        "readable",
        "writable",
        "close",
        "size",
        "seekable",
    ],
)
class ArrowFile(io.IOBase):
    # Thin file-like wrapper around a pyarrow stream. All the I/O methods
    # listed in @mirror_from above are forwarded directly to ``self.stream``;
    # this class only stores context and provides context-manager support.
    def __init__(self, fs, stream, path, mode, block_size=None, **kwargs):
        self.path = path
        self.mode = mode

        self.fs = fs
        self.stream = stream

        # Both spellings kept for compatibility with fsspec file interfaces.
        self.blocksize = self.block_size = block_size
        self.kwargs = kwargs

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Close the underlying stream when leaving a ``with`` block.
        return self.close()
244
+
245
class HadoopFileSystem(ArrowFSWrapper):
    """A wrapper on top of the pyarrow.fs.HadoopFileSystem
    to connect its interface with fsspec"""

    protocol = "hdfs"

    def __init__(
        self,
        host="default",
        port=0,
        user=None,
        kerb_ticket=None,
        replication=3,
        extra_conf=None,
        **kwargs,
    ):
        """

        Parameters
        ----------
        host: str
            Hostname, IP or "default" to try to read from Hadoop config
        port: int
            Port to connect on, or default from Hadoop config if 0
        user: str or None
            If given, connect as this username
        kerb_ticket: str or None
            If given, use this ticket for authentication
        replication: int
            set replication factor of file for write operations. default value is 3.
        extra_conf: None or dict
            Passed on to HadoopFileSystem
        """
        # Import deferred so pyarrow is only required when HDFS is used.
        from pyarrow.fs import HadoopFileSystem

        fs = HadoopFileSystem(
            host=host,
            port=port,
            user=user,
            kerb_ticket=kerb_ticket,
            replication=replication,
            extra_conf=extra_conf,
        )
        super().__init__(fs=fs, **kwargs)

    @staticmethod
    def _get_kwargs_from_urls(path):
        # Derive constructor kwargs from a URL such as
        # hdfs://user@host:port/path?replication=2
        ops = infer_storage_options(path)
        out = {}
        if ops.get("host", None):
            out["host"] = ops["host"]
        if ops.get("username", None):
            out["user"] = ops["username"]
        if ops.get("port", None):
            out["port"] = ops["port"]
        if ops.get("url_query", None):
            # Only the "replication" query parameter is recognised.
            queries = parse_qs(ops["url_query"])
            if queries.get("replication", None):
                out["replication"] = int(queries["replication"][0])
        return out
pythonProject/.venv/Lib/site-packages/fsspec/implementations/asyn_wrapper.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import functools
3
+ import inspect
4
+
5
+ import fsspec
6
+ from fsspec.asyn import AsyncFileSystem, running_async
7
+
8
+
9
def async_wrapper(func, obj=None, semaphore=None):
    """
    Wrap a synchronous callable so it can be awaited.

    The wrapped call runs on a worker thread via ``asyncio.to_thread``; when a
    semaphore is supplied it bounds how many wrapped calls execute at once.

    Parameters
    ----------
    func : callable
        The synchronous function to wrap.
    obj : object, optional
        The instance to bind the function to, if applicable.
    semaphore : asyncio.Semaphore, optional
        A semaphore to limit concurrent calls.

    Returns
    -------
    coroutine
        An awaitable version of the function.
    """

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        if not semaphore:
            return await asyncio.to_thread(func, *args, **kwargs)
        async with semaphore:
            return await asyncio.to_thread(func, *args, **kwargs)

    return wrapper
36
+
37
+
38
class AsyncFileSystemWrapper(AsyncFileSystem):
    """
    A wrapper class to convert a synchronous filesystem into an asynchronous one.

    This class takes an existing synchronous filesystem implementation and wraps all
    its methods to provide an asynchronous interface.

    Parameters
    ----------
    sync_fs : AbstractFileSystem
        The synchronous filesystem instance to wrap.
    """

    protocol = "asyncwrapper", "async_wrapper"
    # Never cache instances: the wrapped filesystem decides identity.
    cachable = False

    def __init__(
        self,
        fs=None,
        asynchronous=None,
        target_protocol=None,
        target_options=None,
        semaphore=None,
        max_concurrent_tasks=None,
        **kwargs,
    ):
        # Default to async mode only when already running inside a loop.
        if asynchronous is None:
            asynchronous = running_async()
        super().__init__(asynchronous=asynchronous, **kwargs)
        # Either wrap the given instance, or build one from protocol/options.
        if fs is not None:
            self.sync_fs = fs
        else:
            self.sync_fs = fsspec.filesystem(target_protocol, **target_options)
        self.protocol = self.sync_fs.protocol
        self.semaphore = semaphore
        self._wrap_all_sync_methods()

    @property
    def fsid(self):
        return f"async_{self.sync_fs.fsid}"

    def _wrap_all_sync_methods(self):
        """
        Wrap all synchronous methods of the underlying filesystem with asynchronous versions.
        """
        # "open" is excluded: file objects need dedicated async handling.
        excluded_methods = {"open"}
        for method_name in dir(self.sync_fs):
            if method_name.startswith("_") or method_name in excluded_methods:
                continue

            # getattr_static avoids triggering property getters while probing.
            attr = inspect.getattr_static(self.sync_fs, method_name)
            if isinstance(attr, property):
                continue

            method = getattr(self.sync_fs, method_name)
            if callable(method) and not inspect.iscoroutinefunction(method):
                # NOTE(review): the ``obj=self`` argument appears unused by
                # async_wrapper's inner wrapper — confirm it is needed.
                async_method = async_wrapper(method, obj=self, semaphore=self.semaphore)
                # Install as _name, the convention AsyncFileSystem dispatches to.
                setattr(self, f"_{method_name}", async_method)

    @classmethod
    def wrap_class(cls, sync_fs_class):
        """
        Create a new class that can be used to instantiate an AsyncFileSystemWrapper
        with lazy instantiation of the underlying synchronous filesystem.

        Parameters
        ----------
        sync_fs_class : type
            The class of the synchronous filesystem to wrap.

        Returns
        -------
        type
            A new class that wraps the provided synchronous filesystem class.
        """

        class GeneratedAsyncFileSystemWrapper(cls):
            def __init__(self, *args, **kwargs):
                sync_fs = sync_fs_class(*args, **kwargs)
                super().__init__(sync_fs)

        GeneratedAsyncFileSystemWrapper.__name__ = (
            f"Async{sync_fs_class.__name__}Wrapper"
        )
        return GeneratedAsyncFileSystemWrapper
pythonProject/.venv/Lib/site-packages/fsspec/implementations/cache_mapper.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+ import hashlib
5
+
6
+ from fsspec.implementations.local import make_path_posix
7
+
8
+
9
class AbstractCacheMapper(abc.ABC):
    """Abstract super-class for mappers from remote URLs to local cached
    basenames.
    """

    @abc.abstractmethod
    def __call__(self, path: str) -> str: ...

    def __eq__(self, other: object) -> bool:
        # Identity only depends on class. When derived classes have attributes
        # they will need to be included.
        return isinstance(other, type(self))

    def __hash__(self) -> int:
        # Identity only depends on class. When derived classes have attributes
        # they will need to be included. Must stay consistent with __eq__.
        return hash(type(self))
26
+
27
+
28
class BasenameCacheMapper(AbstractCacheMapper):
    """Cache mapper that uses the basename of the remote URL and a fixed number
    of directory levels above this.

    The default is zero directory levels, meaning different paths with the same
    basename will have the same cached basename.
    """

    def __init__(self, directory_levels: int = 0):
        if directory_levels < 0:
            raise ValueError(
                "BasenameCacheMapper requires zero or positive directory_levels"
            )
        self.directory_levels = directory_levels

        # Separator for directories when encoded as strings.
        self._separator = "_@_"

    def __call__(self, path: str) -> str:
        posix_path = make_path_posix(path)
        # Keep at most ``directory_levels`` trailing components plus the
        # basename; everything earlier in the path is discarded.
        head, *tail = posix_path.rsplit("/", self.directory_levels + 1)
        if not tail:
            # Path contained no "/" at all: it is already a bare filename.
            return head
        return self._separator.join(tail)

    def __eq__(self, other: object) -> bool:
        return super().__eq__(other) and self.directory_levels == other.directory_levels

    def __hash__(self) -> int:
        return super().__hash__() ^ hash(self.directory_levels)
59
+
60
+
61
class HashCacheMapper(AbstractCacheMapper):
    """Cache mapper that hashes the full remote URL, so distinct paths always
    get distinct local names."""

    def __call__(self, path: str) -> str:
        digest = hashlib.sha256(path.encode())
        return digest.hexdigest()
66
+
67
+
68
+ def create_cache_mapper(same_names: bool) -> AbstractCacheMapper:
69
+ """Factory method to create cache mapper for backward compatibility with
70
+ ``CachingFileSystem`` constructor using ``same_names`` kwarg.
71
+ """
72
+ if same_names:
73
+ return BasenameCacheMapper()
74
+ else:
75
+ return HashCacheMapper()
pythonProject/.venv/Lib/site-packages/fsspec/implementations/cache_metadata.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import pickle
5
+ import time
6
+ from typing import TYPE_CHECKING
7
+
8
+ from fsspec.utils import atomic_write
9
+
10
+ try:
11
+ import ujson as json
12
+ except ImportError:
13
+ if not TYPE_CHECKING:
14
+ import json
15
+
16
+ if TYPE_CHECKING:
17
+ from collections.abc import Iterator
18
+ from typing import Any, Literal
19
+
20
+ from typing_extensions import TypeAlias
21
+
22
+ from .cached import CachingFileSystem
23
+
24
+ Detail: TypeAlias = dict[str, Any]
25
+
26
+
27
class CacheMetadata:
    """Cache metadata.

    All reading and writing of cache metadata is performed by this class,
    accessing the cached files and blocks is not.

    Metadata is stored in a single file per storage directory in JSON format.
    For backward compatibility, also reads metadata stored in pickle format
    which is converted to JSON when next saved.
    """

    def __init__(self, storage: list[str]):
        """

        Parameters
        ----------
        storage: list[str]
            Directories containing cached files, must be at least one. Metadata
            is stored in the last of these directories by convention.
        """
        if not storage:
            raise ValueError("CacheMetadata expects at least one storage location")

        self._storage = storage
        # One metadata dict per storage location, in the same order as
        # self._storage; only the last entry is considered writable.
        self.cached_files: list[Detail] = [{}]

        # Private attribute to force saving of metadata in pickle format rather than
        # JSON for use in tests to confirm can read both pickle and JSON formats.
        self._force_save_pickle = False

    def _load(self, fn: str) -> Detail:
        """Low-level function to load metadata from specific file"""
        try:
            with open(fn, "r") as f:
                loaded = json.load(f)
        except ValueError:
            # Not valid JSON: assume legacy pickle format and re-read as bytes.
            with open(fn, "rb") as f:
                loaded = pickle.load(f)
        for c in loaded.values():
            # JSON serializes the block set as a list; restore to a set so
            # membership tests and updates stay O(1).
            if isinstance(c.get("blocks"), list):
                c["blocks"] = set(c["blocks"])
        return loaded

    def _save(self, metadata_to_save: Detail, fn: str) -> None:
        """Low-level function to save metadata to specific file"""
        # atomic_write avoids readers seeing a partially-written metadata file.
        if self._force_save_pickle:
            with atomic_write(fn) as f:
                pickle.dump(metadata_to_save, f)
        else:
            with atomic_write(fn, mode="w") as f:
                json.dump(metadata_to_save, f)

    def _scan_locations(
        self, writable_only: bool = False
    ) -> Iterator[tuple[str, str, bool]]:
        """Yield locations (filenames) where metadata is stored, and whether
        writable or not.

        Parameters
        ----------
        writable: bool
            Set to True to only yield writable locations.

        Returns
        -------
        Yields (str, str, bool)
        """
        n = len(self._storage)
        for i, storage in enumerate(self._storage):
            # By convention only the last storage location is writable.
            writable = i == n - 1
            if writable_only and not writable:
                continue
            yield os.path.join(storage, "cache"), storage, writable

    def check_file(
        self, path: str, cfs: CachingFileSystem | None
    ) -> Literal[False] | tuple[Detail, str]:
        """If path is in cache return its details, otherwise return ``False``.

        If the optional CachingFileSystem is specified then it is used to
        perform extra checks to reject possible matches, such as if they are
        too old.
        """
        for (fn, base, _), cache in zip(self._scan_locations(), self.cached_files):
            if path not in cache:
                continue
            detail = cache[path].copy()

            if cfs is not None:
                if cfs.check_files and detail["uid"] != cfs.fs.ukey(path):
                    # Wrong file as determined by hash of file properties
                    continue
                if cfs.expiry and time.time() - detail["time"] > cfs.expiry:
                    # Cached file has expired
                    continue

            fn = os.path.join(base, detail["fn"])
            if os.path.exists(fn):
                return detail, fn
        return False

    def clear_expired(self, expiry_time: int) -> tuple[list[str], bool]:
        """Remove expired metadata from the cache.

        Returns names of files corresponding to expired metadata and a boolean
        flag indicating whether the writable cache is empty. Caller is
        responsible for deleting the expired files.
        """
        expired_files = []
        # Iterate over a copy since entries are popped while scanning.
        for path, detail in self.cached_files[-1].copy().items():
            if time.time() - detail["time"] > expiry_time:
                fn = detail.get("fn", "")
                if not fn:
                    raise RuntimeError(
                        f"Cache metadata does not contain 'fn' for {path}"
                    )
                fn = os.path.join(self._storage[-1], fn)
                expired_files.append(fn)
                self.cached_files[-1].pop(path)

        # Persist the pruned metadata only if something remains.
        if self.cached_files[-1]:
            cache_path = os.path.join(self._storage[-1], "cache")
            self._save(self.cached_files[-1], cache_path)

        writable_cache_empty = not self.cached_files[-1]
        return expired_files, writable_cache_empty

    def load(self) -> None:
        """Load all metadata from disk and store in ``self.cached_files``"""
        cached_files = []
        for fn, _, _ in self._scan_locations():
            if os.path.exists(fn):
                # TODO: consolidate blocks here
                cached_files.append(self._load(fn))
            else:
                cached_files.append({})
        self.cached_files = cached_files or [{}]

    def on_close_cached_file(self, f: Any, path: str) -> None:
        """Perform side-effect actions on closing a cached file.

        The actual closing of the file is the responsibility of the caller.
        """
        # File must be writeble, so in self.cached_files[-1]
        c = self.cached_files[-1][path]
        # If every block has been fetched, collapse the set to True (complete).
        if c["blocks"] is not True and len(c["blocks"]) * f.blocksize >= f.size:
            c["blocks"] = True

    def pop_file(self, path: str) -> str | None:
        """Remove metadata of cached file.

        If path is in the cache, return the filename of the cached file,
        otherwise return ``None``. Caller is responsible for deleting the
        cached file.
        """
        details = self.check_file(path, None)
        if not details:
            return None
        _, fn = details
        if fn.startswith(self._storage[-1]):
            self.cached_files[-1].pop(path)
            self.save()
        else:
            raise PermissionError(
                "Can only delete cached file in last, writable cache location"
            )
        return fn

    def save(self) -> None:
        """Save metadata to disk"""
        for (fn, _, writable), cache in zip(self._scan_locations(), self.cached_files):
            if not writable:
                continue

            if os.path.exists(fn):
                # Merge with what another process may have written meanwhile.
                cached_files = self._load(fn)
                for k, c in cached_files.items():
                    if k in cache:
                        if c["blocks"] is True or cache[k]["blocks"] is True:
                            c["blocks"] = True
                        else:
                            # self.cached_files[*][*]["blocks"] must continue to
                            # point to the same set object so that updates
                            # performed by MMapCache are propagated back to
                            # self.cached_files.
                            blocks = cache[k]["blocks"]
                            blocks.update(c["blocks"])
                            c["blocks"] = blocks
                        c["time"] = max(c["time"], cache[k]["time"])
                        c["uid"] = cache[k]["uid"]

                # Files can be added to cache after it was written once
                for k, c in cache.items():
                    if k not in cached_files:
                        cached_files[k] = c
            else:
                cached_files = cache
            # Serialize a copy: block sets become lists for JSON, but the
            # in-memory entries keep their (shared) set objects.
            cache = {k: v.copy() for k, v in cached_files.items()}
            for c in cache.values():
                if isinstance(c["blocks"], set):
                    c["blocks"] = list(c["blocks"])
            self._save(cache, fn)
            self.cached_files[-1] = cached_files

    def update_file(self, path: str, detail: Detail) -> None:
        """Update metadata for specific file in memory, do not save"""
        self.cached_files[-1][path] = detail
pythonProject/.venv/Lib/site-packages/fsspec/implementations/cached.py ADDED
@@ -0,0 +1,998 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import logging
5
+ import os
6
+ import tempfile
7
+ import time
8
+ import weakref
9
+ from shutil import rmtree
10
+ from typing import TYPE_CHECKING, Any, Callable, ClassVar
11
+
12
+ from fsspec import AbstractFileSystem, filesystem
13
+ from fsspec.callbacks import DEFAULT_CALLBACK
14
+ from fsspec.compression import compr
15
+ from fsspec.core import BaseCache, MMapCache
16
+ from fsspec.exceptions import BlocksizeMismatchError
17
+ from fsspec.implementations.cache_mapper import create_cache_mapper
18
+ from fsspec.implementations.cache_metadata import CacheMetadata
19
+ from fsspec.implementations.local import LocalFileSystem
20
+ from fsspec.spec import AbstractBufferedFile
21
+ from fsspec.transaction import Transaction
22
+ from fsspec.utils import infer_compression
23
+
24
+ if TYPE_CHECKING:
25
+ from fsspec.implementations.cache_mapper import AbstractCacheMapper
26
+
27
# Module-level logger shared by all caching filesystem classes below.
logger = logging.getLogger("fsspec.cached")
28
+
29
+
30
class WriteCachedTransaction(Transaction):
    """Transaction that uploads locally-cached writes to the remote
    filesystem when committed, then tears down the transaction state."""

    def complete(self, commit=True):
        """Finish the transaction.

        On commit, push every locally staged file to its remote path; either
        way, clear the pending file list and detach from the filesystem.
        """
        remote_paths = [f.path for f in self.files]
        local_paths = [f.fn for f in self.files]
        if commit:
            self.fs.put(local_paths, remote_paths)
        self.files.clear()
        self.fs._intrans = False
        self.fs._transaction = None
        self.fs = None  # break cycle
40
+
41
+
42
class CachingFileSystem(AbstractFileSystem):
    """Locally caching filesystem, layer over any other FS

    This class implements chunk-wise local storage of remote files, for quick
    access after the initial download. The files are stored in a given
    directory with hashes of URLs for the filenames. If no directory is given,
    a temporary one is used, which should be cleaned up by the OS after the
    process ends. The files themselves are sparse (as implemented in
    :class:`~fsspec.caching.MMapCache`), so only the data which is accessed
    takes up space.

    Restrictions:

    - the block-size must be the same for each access of a given file, unless
      all blocks of the file have already been read
    - caching can only be applied to file-systems which produce files
      derived from fsspec.spec.AbstractBufferedFile ; LocalFileSystem is also
      allowed, for testing
    """

    protocol: ClassVar[str | tuple[str, ...]] = ("blockcache", "cached")

    def __init__(
        self,
        target_protocol=None,
        cache_storage="TMP",
        cache_check=10,
        check_files=False,
        expiry_time=604800,
        target_options=None,
        fs=None,
        same_names: bool | None = None,
        compression=None,
        cache_mapper: AbstractCacheMapper | None = None,
        **kwargs,
    ):
        """

        Parameters
        ----------
        target_protocol: str (optional)
            Target filesystem protocol. Provide either this or ``fs``.
        cache_storage: str or list(str)
            Location to store files. If "TMP", this is a temporary directory,
            and will be cleaned up by the OS when this process ends (or later).
            If a list, each location will be tried in the order given, but
            only the last will be considered writable.
        cache_check: int
            Number of seconds between reload of cache metadata
        check_files: bool
            Whether to explicitly see if the UID of the remote file matches
            the stored one before using. Warning: some file systems such as
            HTTP cannot reliably give a unique hash of the contents of some
            path, so be sure to set this option to False.
        expiry_time: int
            The time in seconds after which a local copy is considered useless.
            Set to falsy to prevent expiry. The default is equivalent to one
            week.
        target_options: dict or None
            Passed to the instantiation of the FS, if fs is None.
        fs: filesystem instance
            The target filesystem to run against. Provide this or ``protocol``.
        same_names: bool (optional)
            By default, target URLs are hashed using a ``HashCacheMapper`` so
            that files from different backends with the same basename do not
            conflict. If this argument is ``true``, a ``BasenameCacheMapper``
            is used instead. Other cache mapper options are available by using
            the ``cache_mapper`` keyword argument. Only one of this and
            ``cache_mapper`` should be specified.
        compression: str (optional)
            To decompress on download. Can be 'infer' (guess from the URL name),
            one of the entries in ``fsspec.compression.compr``, or None for no
            decompression.
        cache_mapper: AbstractCacheMapper (optional)
            The object use to map from original filenames to cached filenames.
            Only one of this and ``same_names`` should be specified.
        """
        super().__init__(**kwargs)
        if fs is None and target_protocol is None:
            raise ValueError(
                "Please provide filesystem instance(fs) or target_protocol"
            )
        if not (fs is None) ^ (target_protocol is None):
            raise ValueError(
                "Both filesystems (fs) and target_protocol may not be both given."
            )
        if cache_storage == "TMP":
            tempdir = tempfile.mkdtemp()
            storage = [tempdir]
            # Best-effort removal of the temp dir when this instance is
            # garbage-collected.
            weakref.finalize(self, self._remove_tempdir, tempdir)
        else:
            if isinstance(cache_storage, str):
                storage = [cache_storage]
            else:
                storage = cache_storage
        os.makedirs(storage[-1], exist_ok=True)
        self.storage = storage
        self.kwargs = target_options or {}
        self.cache_check = cache_check
        self.check_files = check_files
        self.expiry = expiry_time
        self.compression = compression

        # Size of cache in bytes. If None then the size is unknown and will be
        # recalculated the next time cache_size() is called. On writes to the
        # cache this is reset to None.
        self._cache_size = None

        if same_names is not None and cache_mapper is not None:
            raise ValueError(
                "Cannot specify both same_names and cache_mapper in "
                "CachingFileSystem.__init__"
            )
        if cache_mapper is not None:
            self._mapper = cache_mapper
        else:
            self._mapper = create_cache_mapper(
                same_names if same_names is not None else False
            )

        self.target_protocol = (
            target_protocol
            if isinstance(target_protocol, str)
            else (fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0])
        )
        self._metadata = CacheMetadata(self.storage)
        self.load_cache()
        self.fs = fs if fs is not None else filesystem(target_protocol, **self.kwargs)

        def _strip_protocol(path):
            # acts as a method, since each instance has a difference target
            return self.fs._strip_protocol(type(self)._strip_protocol(path))

        self._strip_protocol: Callable = _strip_protocol

    @staticmethod
    def _remove_tempdir(tempdir):
        # Finalizer callback: ignore failures, the OS cleans TMP eventually.
        try:
            rmtree(tempdir)
        except Exception:
            pass

    def _mkcache(self):
        # Ensure the writable cache directory exists.
        os.makedirs(self.storage[-1], exist_ok=True)

    def cache_size(self):
        """Return size of cache in bytes.

        If more than one cache directory is in use, only the size of the last
        one (the writable cache directory) is returned.
        """
        if self._cache_size is None:
            cache_dir = self.storage[-1]
            self._cache_size = filesystem("file").du(cache_dir, withdirs=True)
        return self._cache_size

    def load_cache(self):
        """Read set of stored blocks from file"""
        self._metadata.load()
        self._mkcache()
        self.last_cache = time.time()

    def save_cache(self):
        """Save set of stored blocks from file"""
        self._mkcache()
        self._metadata.save()
        self.last_cache = time.time()
        self._cache_size = None

    def _check_cache(self):
        """Reload caches if time elapsed or any disappeared"""
        self._mkcache()
        if not self.cache_check:
            # explicitly told not to bother checking
            return
        timecond = time.time() - self.last_cache > self.cache_check
        existcond = all(os.path.exists(storage) for storage in self.storage)
        if timecond or not existcond:
            self.load_cache()

    def _check_file(self, path):
        """Is path in cache and still valid"""
        path = self._strip_protocol(path)
        self._check_cache()
        return self._metadata.check_file(path, self)

    def clear_cache(self):
        """Remove all files and metadata from the cache

        In the case of multiple cache locations, this clears only the last one,
        which is assumed to be the read/write one.
        """
        rmtree(self.storage[-1])
        self.load_cache()
        self._cache_size = None

    def clear_expired_cache(self, expiry_time=None):
        """Remove all expired files and metadata from the cache

        In the case of multiple cache locations, this clears only the last one,
        which is assumed to be the read/write one.

        Parameters
        ----------
        expiry_time: int
            The time in seconds after which a local copy is considered useless.
            If not defined the default is equivalent to the attribute from the
            file caching instantiation.
        """

        if not expiry_time:
            expiry_time = self.expiry

        self._check_cache()

        expired_files, writable_cache_empty = self._metadata.clear_expired(expiry_time)
        for fn in expired_files:
            if os.path.exists(fn):
                os.remove(fn)

        if writable_cache_empty:
            rmtree(self.storage[-1])
            self.load_cache()

        self._cache_size = None

    def pop_from_cache(self, path):
        """Remove cached version of given file

        Deletes local copy of the given (remote) path. If it is found in a cache
        location which is not the last, it is assumed to be read-only, and
        raises PermissionError
        """
        path = self._strip_protocol(path)
        fn = self._metadata.pop_file(path)
        if fn is not None:
            os.remove(fn)
        self._cache_size = None

    def _open(
        self,
        path,
        mode="rb",
        block_size=None,
        autocommit=True,
        cache_options=None,
        **kwargs,
    ):
        """Wrap the target _open

        If the whole file exists in the cache, just open it locally and
        return that.

        Otherwise, open the file on the target FS, and make it have a mmap
        cache pointing to the location which we determine, in our cache.
        The ``blocks`` instance is shared, so as the mmap cache instance
        updates, so does the entry in our ``cached_files`` attribute.
        We monkey-patch this file, so that when it closes, we call
        ``close_and_update`` to save the state of the blocks.
        """
        path = self._strip_protocol(path)

        path = self.fs._strip_protocol(path)
        if "r" not in mode:
            # Writes bypass the cache entirely.
            return self.fs._open(
                path,
                mode=mode,
                block_size=block_size,
                autocommit=autocommit,
                cache_options=cache_options,
                **kwargs,
            )
        detail = self._check_file(path)
        if detail:
            # file is in cache
            detail, fn = detail
            hash, blocks = detail["fn"], detail["blocks"]
            if blocks is True:
                # stored file is complete
                logger.debug("Opening local copy of %s", path)
                return open(fn, mode)
            # TODO: action where partial file exists in read-only cache
            logger.debug("Opening partially cached copy of %s", path)
        else:
            hash = self._mapper(path)
            fn = os.path.join(self.storage[-1], hash)
            blocks = set()
            detail = {
                "original": path,
                "fn": hash,
                "blocks": blocks,
                "time": time.time(),
                "uid": self.fs.ukey(path),
            }
            self._metadata.update_file(path, detail)
            logger.debug("Creating local sparse file for %s", path)

        # explicitly submitting the size to the open call will avoid extra
        # operations when opening. This is particularly relevant
        # for any file that is read over a network, e.g. S3.
        size = detail.get("size")

        # call target filesystems open
        self._mkcache()
        f = self.fs._open(
            path,
            mode=mode,
            block_size=block_size,
            autocommit=autocommit,
            cache_options=cache_options,
            cache_type="none",
            size=size,
            **kwargs,
        )

        # set size if not already set
        if size is None:
            detail["size"] = f.size
            self._metadata.update_file(path, detail)

        if self.compression:
            comp = (
                infer_compression(path)
                if self.compression == "infer"
                else self.compression
            )
            f = compr[comp](f, mode="rb")
        if "blocksize" in detail:
            if detail["blocksize"] != f.blocksize:
                raise BlocksizeMismatchError(
                    f"Cached file must be reopened with same block"
                    f" size as original (old: {detail['blocksize']},"
                    f" new {f.blocksize})"
                )
        else:
            detail["blocksize"] = f.blocksize

        def _fetch_ranges(ranges):
            # Batch-fetch multiple byte ranges from the target FS in one call.
            return self.fs.cat_ranges(
                [path] * len(ranges),
                [r[0] for r in ranges],
                [r[1] for r in ranges],
                **kwargs,
            )

        # Range batching is incompatible with on-the-fly decompression.
        multi_fetcher = None if self.compression else _fetch_ranges
        # NOTE: `blocks` is the same set object stored in the metadata entry,
        # so MMapCache updates propagate to the metadata automatically.
        f.cache = MMapCache(
            f.blocksize, f._fetch_range, f.size, fn, blocks, multi_fetcher=multi_fetcher
        )
        close = f.close
        f.close = lambda: self.close_and_update(f, close)
        self.save_cache()
        return f

    def _parent(self, path):
        # Delegate parent resolution to the target filesystem.
        return self.fs._parent(path)

    def hash_name(self, path: str, *args: Any) -> str:
        # Kept for backward compatibility with downstream libraries.
        # Ignores extra arguments, previously same_name boolean.
        return self._mapper(path)

    def close_and_update(self, f, close):
        """Called when a file is closing, so store the set of blocks"""
        if f.closed:
            return
        path = self._strip_protocol(f.path)
        self._metadata.on_close_cached_file(f, path)
        try:
            logger.debug("going to save")
            self.save_cache()
            logger.debug("saved")
        except OSError:
            logger.debug("Cache saving failed while closing file")
        except NameError:
            logger.debug("Cache save failed due to interpreter shutdown")
        close()
        f.closed = True

    def ls(self, path, detail=True):
        # Listing always comes from the target filesystem, never the cache.
        return self.fs.ls(path, detail)

    def __getattribute__(self, item):
        # Attribute dispatch: methods defined on this class run here; almost
        # everything else is proxied through to the wrapped filesystem.
        if item in {
            "load_cache",
            "_open",
            "save_cache",
            "close_and_update",
            "__init__",
            "__getattribute__",
            "__reduce__",
            "_make_local_details",
            "open",
            "cat",
            "cat_file",
            "_cat_file",
            "cat_ranges",
            "_cat_ranges",
            "get",
            "read_block",
            "tail",
            "head",
            "info",
            "ls",
            "exists",
            "isfile",
            "isdir",
            "_check_file",
            "_check_cache",
            "_mkcache",
            "clear_cache",
            "clear_expired_cache",
            "pop_from_cache",
            "local_file",
            "_paths_from_path",
            "get_mapper",
            "open_many",
            "commit_many",
            "hash_name",
            "__hash__",
            "__eq__",
            "to_json",
            "to_dict",
            "cache_size",
            "pipe_file",
            "pipe",
            "start_transaction",
            "end_transaction",
        }:
            # all the methods defined in this class. Note `open` here, since
            # it calls `_open`, but is actually in superclass
            return lambda *args, **kw: getattr(type(self), item).__get__(self)(
                *args, **kw
            )
        if item in ["__reduce_ex__"]:
            raise AttributeError
        if item in ["transaction"]:
            # property
            return type(self).transaction.__get__(self)
        if item in {"_cache", "transaction_type", "protocol"}:
            # class attributes
            return getattr(type(self), item)
        if item == "__class__":
            return type(self)
        d = object.__getattribute__(self, "__dict__")
        fs = d.get("fs", None)  # fs is not immediately defined
        if item in d:
            return d[item]
        elif fs is not None:
            if item in fs.__dict__:
                # attribute of instance
                return fs.__dict__[item]
            # attributed belonging to the target filesystem
            cls = type(fs)
            m = getattr(cls, item)
            if (inspect.isfunction(m) or inspect.isdatadescriptor(m)) and (
                not hasattr(m, "__self__") or m.__self__ is None
            ):
                # instance method
                return m.__get__(fs, cls)
            return m  # class method or attribute
        else:
            # attributes of the superclass, while target is being set up
            return super().__getattribute__(item)

    def __eq__(self, other):
        """Test for equality."""
        if self is other:
            return True
        if not isinstance(other, type(self)):
            return False
        return (
            self.storage == other.storage
            and self.kwargs == other.kwargs
            and self.cache_check == other.cache_check
            and self.check_files == other.check_files
            and self.expiry == other.expiry
            and self.compression == other.compression
            and self._mapper == other._mapper
            and self.target_protocol == other.target_protocol
        )

    def __hash__(self):
        """Calculate hash."""
        return (
            hash(tuple(self.storage))
            ^ hash(str(self.kwargs))
            ^ hash(self.cache_check)
            ^ hash(self.check_files)
            ^ hash(self.expiry)
            ^ hash(self.compression)
            ^ hash(self._mapper)
            ^ hash(self.target_protocol)
        )
536
+
537
+
538
class WholeFileCacheFileSystem(CachingFileSystem):
    """Caches whole remote files on first access

    This class is intended as a layer over any other file system, and
    will make a local copy of each file accessed, so that all subsequent
    reads are local. This is similar to ``CachingFileSystem``, but without
    the block-wise functionality and so can work even when sparse files
    are not allowed. See its docstring for definition of the init
    arguments.

    The class still needs access to the remote store for listing files,
    and may refresh cached files.
    """

    protocol = "filecache"
    # Marks that files returned here are real local files (see base class
    # dispatch set).
    local_file = True

    def open_many(self, open_files, **kwargs):
        """Open several files at once, downloading any not yet cached.

        Returns a list of open local file objects in the same order as the
        input ``open_files``.
        """
        paths = [of.path for of in open_files]
        if "r" in open_files.mode:
            self._mkcache()
        else:
            # Write mode: stage to local temp files, uploaded on commit.
            return [
                LocalTempFile(
                    self.fs,
                    path,
                    mode=open_files.mode,
                    fn=os.path.join(self.storage[-1], self._mapper(path)),
                    **kwargs,
                )
                for path in paths
            ]

        if self.compression:
            raise NotImplementedError
        details = [self._check_file(sp) for sp in paths]
        downpath = [p for p, d in zip(paths, details) if not d]
        downfn0 = [
            os.path.join(self.storage[-1], self._mapper(p))
            for p, d in zip(paths, details)
        ]  # keep these path names for opening later
        downfn = [fn for fn, d in zip(downfn0, details) if not d]
        if downpath:
            # skip if all files are already cached and up to date
            self.fs.get(downpath, downfn)

            # update metadata - only happens when downloads are successful
            newdetail = [
                {
                    "original": path,
                    "fn": self._mapper(path),
                    "blocks": True,
                    "time": time.time(),
                    "uid": self.fs.ukey(path),
                }
                for path in downpath
            ]
            for path, detail in zip(downpath, newdetail):
                self._metadata.update_file(path, detail)
            self.save_cache()

        def firstpart(fn):
            # helper to adapt both whole-file and simple-cache
            return fn[1] if isinstance(fn, tuple) else fn

        return [
            open(firstpart(fn0) if fn0 else fn1, mode=open_files.mode)
            for fn0, fn1 in zip(details, downfn0)
        ]

    def commit_many(self, open_files):
        """Upload staged temp files to their remote paths and clean up."""
        self.fs.put([f.fn for f in open_files], [f.path for f in open_files])
        [f.close() for f in open_files]
        for f in open_files:
            # in case autocommit is off, and so close did not already delete
            try:
                os.remove(f.name)
            except FileNotFoundError:
                pass
        self._cache_size = None

    def _make_local_details(self, path):
        # Register a new, fully-cached ("blocks": True) metadata entry and
        # return the local filename the download should go to.
        hash = self._mapper(path)
        fn = os.path.join(self.storage[-1], hash)
        detail = {
            "original": path,
            "fn": hash,
            "blocks": True,
            "time": time.time(),
            "uid": self.fs.ukey(path),
        }
        self._metadata.update_file(path, detail)
        logger.debug("Copying %s to local cache", path)
        return fn

    def cat(
        self,
        path,
        recursive=False,
        on_error="raise",
        callback=DEFAULT_CALLBACK,
        **kwargs,
    ):
        """Fetch contents of path(s), downloading any missing files to the
        cache first and then reading them locally."""
        paths = self.expand_path(
            path, recursive=recursive, maxdepth=kwargs.get("maxdepth")
        )
        getpaths = []
        storepaths = []
        fns = []
        out = {}
        for p in paths.copy():
            try:
                detail = self._check_file(p)
                if not detail:
                    fn = self._make_local_details(p)
                    getpaths.append(p)
                    storepaths.append(fn)
                else:
                    # check_file may return (detail, fn) or just fn.
                    detail, fn = detail if isinstance(detail, tuple) else (None, detail)
                fns.append(fn)
            except Exception as e:
                if on_error == "raise":
                    raise
                if on_error == "return":
                    out[p] = e
                paths.remove(p)

        if getpaths:
            self.fs.get(getpaths, storepaths)
            self.save_cache()

        callback.set_size(len(paths))
        for p, fn in zip(paths, fns):
            with open(fn, "rb") as f:
                out[p] = f.read()
            callback.relative_update(1)
        # Single non-recursive string input: return bytes, not a dict.
        if isinstance(path, str) and len(paths) == 1 and recursive is False:
            out = out[paths[0]]
        return out

    def _open(self, path, mode="rb", **kwargs):
        path = self._strip_protocol(path)
        if "r" not in mode:
            # Write mode: stage to a local temp file, uploaded on commit.
            hash = self._mapper(path)
            fn = os.path.join(self.storage[-1], hash)
            user_specified_kwargs = {
                k: v
                for k, v in kwargs.items()
                # those kwargs were added by open(), we don't want them
                if k not in ["autocommit", "block_size", "cache_options"]
            }
            return LocalTempFile(self, path, mode=mode, fn=fn, **user_specified_kwargs)
        detail = self._check_file(path)
        if detail:
            detail, fn = detail
            _, blocks = detail["fn"], detail["blocks"]
            if blocks is True:
                logger.debug("Opening local copy of %s", path)

                # In order to support downstream filesystems to be able to
                # infer the compression from the original filename, like
                # the `TarFileSystem`, let's extend the `io.BufferedReader`
                # fileobject protocol by adding a dedicated attribute
                # `original`.
                f = open(fn, mode)
                f.original = detail.get("original")
                return f
            else:
                raise ValueError(
                    f"Attempt to open partially cached file {path}"
                    f" as a wholly cached file"
                )
        else:
            fn = self._make_local_details(path)
        kwargs["mode"] = mode

        # call target filesystems open
        self._mkcache()
        if self.compression:
            # Download through a decompressor, streaming block by block.
            with self.fs._open(path, **kwargs) as f, open(fn, "wb") as f2:
                if isinstance(f, AbstractBufferedFile):
                    # want no type of caching if just downloading whole thing
                    f.cache = BaseCache(0, f.cache.fetcher, f.size)
                comp = (
                    infer_compression(path)
                    if self.compression == "infer"
                    else self.compression
                )
                f = compr[comp](f, mode="rb")
                data = True
                while data:
                    block = getattr(f, "blocksize", 5 * 2**20)
                    data = f.read(block)
                    f2.write(data)
        else:
            self.fs.get_file(path, fn)
        self.save_cache()
        # Re-enter: the file is now fully cached and opens via the local path.
        return self._open(path, mode)
+
737
+
738
class SimpleCacheFileSystem(WholeFileCacheFileSystem):
    """Caches whole remote files on first access

    This class is intended as a layer over any other file system, and
    will make a local copy of each file accessed, so that all subsequent
    reads are local. This implementation only copies whole files, and
    does not keep any metadata about the download time or file details.
    It is therefore safer to use in multi-threaded/concurrent situations.

    This is the only of the caching filesystems that supports write: you will
    be given a real local open file, and upon close and commit, it will be
    uploaded to the target filesystem; the writability or the target URL is
    not checked until that time.

    """

    protocol = "simplecache"
    local_file = True
    transaction_type = WriteCachedTransaction

    def __init__(self, **kwargs):
        kw = kwargs.copy()
        # this implementation keeps no metadata, so the parent's expiry and
        # consistency-checking options are forcibly disabled
        for key in ["cache_check", "expiry_time", "check_files"]:
            kw[key] = False
        super().__init__(**kw)
        for storage in self.storage:
            if not os.path.exists(storage):
                os.makedirs(storage, exist_ok=True)

    def _check_file(self, path):
        """Return the local cache filename for ``path``, or None on a miss."""
        self._check_cache()
        sha = self._mapper(path)
        for storage in self.storage:
            fn = os.path.join(storage, sha)
            if os.path.exists(fn):
                return fn
        return None

    def save_cache(self):
        """No-op: this implementation stores no metadata."""
        pass

    def load_cache(self):
        """No-op: this implementation stores no metadata."""
        pass

    def pipe_file(self, path, value=None, **kwargs):
        """Write ``value`` to ``path``; inside a transaction the write is
        buffered via an open file and committed with the transaction."""
        if self._intrans:
            with self.open(path, "wb") as f:
                f.write(value)
        else:
            super().pipe_file(path, value)

    def ls(self, path, detail=True, **kwargs):
        """List ``path`` on the target filesystem, merging in any files
        written (but not yet committed) in the current transaction."""
        path = self._strip_protocol(path)
        details = []
        try:
            details = self.fs.ls(
                path, detail=True, **kwargs
            ).copy()  # don't edit original!
        except FileNotFoundError as e:
            ex = e
        else:
            ex = None
        if self._intrans:
            path1 = path.rstrip("/") + "/"
            for f in self.transaction.files:
                if f.path == path:
                    details.append(
                        {"name": path, "size": f.size or f.tell(), "type": "file"}
                    )
                elif f.path.startswith(path1):
                    if f.path.count("/") == path1.count("/"):
                        # direct child of `path`
                        details.append(
                            {"name": f.path, "size": f.size or f.tell(), "type": "file"}
                        )
                    else:
                        # deeper descendant: surface the intermediate directory
                        dname = "/".join(f.path.split("/")[: path1.count("/") + 1])
                        details.append({"name": dname, "size": 0, "type": "directory"})
        if ex is not None and not details:
            raise ex
        if detail:
            return details
        return sorted(_["name"] for _ in details)

    def info(self, path, **kwargs):
        """Info for ``path``, taking uncommitted transaction files into account."""
        path = self._strip_protocol(path)
        if self._intrans:
            f = [_ for _ in self.transaction.files if _.path == path]
            if f:
                # closed files have their final size on disk; open ones report
                # the current write position
                size = os.path.getsize(f[0].fn) if f[0].closed else f[0].tell()
                return {"name": path, "size": size, "type": "file"}
            f = any(_.path.startswith(path + "/") for _ in self.transaction.files)
            if f:
                return {"name": path, "size": 0, "type": "directory"}
        return self.fs.info(path, **kwargs)

    def pipe(self, path, value=None, **kwargs):
        if isinstance(path, str):
            self.pipe_file(self._strip_protocol(path), value, **kwargs)
        elif isinstance(path, dict):
            for k, v in path.items():
                self.pipe_file(self._strip_protocol(k), v, **kwargs)
        else:
            raise ValueError("path must be str or dict")

    async def _cat_file(self, path, start=None, end=None, **kwargs):
        logger.debug("async cat_file %s", path)
        path = self._strip_protocol(path)
        sha = self._mapper(path)
        fn = self._check_file(path)

        if not fn:
            # cache miss: download the whole file first
            fn = os.path.join(self.storage[-1], sha)
            await self.fs._get_file(path, fn, **kwargs)

        with open(fn, "rb") as f:  # noqa ASYNC230
            if start:
                f.seek(start)
            size = -1 if end is None else end - f.tell()
            return f.read(size)

    async def _cat_ranges(
        self, paths, starts, ends, max_gap=None, on_error="return", **kwargs
    ):
        logger.debug("async cat ranges %s", paths)
        lpaths = []
        rset = set()
        download = []
        rpaths = []
        for p in paths:
            fn = self._check_file(p)
            if fn is None:
                # compute the cache destination for every miss (fix: previously
                # a repeated missing path left a None entry in lpaths)
                sha = self._mapper(p)
                fn = os.path.join(self.storage[-1], sha)
                if p not in rset:
                    # only schedule each distinct path for download once
                    download.append(fn)
                    rset.add(p)
                    rpaths.append(p)
            lpaths.append(fn)
        if download:
            await self.fs._get(rpaths, download, on_error=on_error)

        return LocalFileSystem().cat_ranges(
            lpaths, starts, ends, max_gap=max_gap, on_error=on_error, **kwargs
        )

    def cat_ranges(
        self, paths, starts, ends, max_gap=None, on_error="return", **kwargs
    ):
        logger.debug("cat ranges %s", paths)
        lpaths = [self._check_file(p) for p in paths]
        # Fix: _check_file returns None (not False) on a miss, so the previous
        # filter `l is False` never matched and uncached paths were never
        # downloaded; it also never computed the local download destinations.
        # Mirror the async implementation above instead.
        rpaths = []
        cpaths = []
        for p, l in zip(paths, lpaths):
            if l is None:
                sha = self._mapper(p)
                rpaths.append(p)
                cpaths.append(os.path.join(self.storage[-1], sha))
        if rpaths:
            self.fs.get(rpaths, cpaths)
        paths = [self._check_file(p) for p in paths]
        return LocalFileSystem().cat_ranges(
            paths, starts, ends, max_gap=max_gap, on_error=on_error, **kwargs
        )

    def _open(self, path, mode="rb", **kwargs):
        """Open ``path``; reads come from (and populate) the local cache,
        writes go to a local temp file uploaded on commit."""
        path = self._strip_protocol(path)
        sha = self._mapper(path)

        if "r" not in mode:
            fn = os.path.join(self.storage[-1], sha)
            user_specified_kwargs = {
                k: v
                for k, v in kwargs.items()
                if k not in ["autocommit", "block_size", "cache_options"]
            }  # those were added by open()
            return LocalTempFile(
                self,
                path,
                mode=mode,
                autocommit=not self._intrans,
                fn=fn,
                **user_specified_kwargs,
            )
        fn = self._check_file(path)
        if fn:
            return open(fn, mode)

        fn = os.path.join(self.storage[-1], sha)
        logger.debug("Copying %s to local cache", path)
        kwargs["mode"] = mode

        self._mkcache()
        self._cache_size = None
        if self.compression:
            # stream and decompress into the cache, so the local copy is
            # stored uncompressed
            with self.fs._open(path, **kwargs) as f, open(fn, "wb") as f2:
                if isinstance(f, AbstractBufferedFile):
                    # want no type of caching if just downloading whole thing
                    f.cache = BaseCache(0, f.cache.fetcher, f.size)
                comp = (
                    infer_compression(path)
                    if self.compression == "infer"
                    else self.compression
                )
                f = compr[comp](f, mode="rb")
                data = True
                while data:
                    block = getattr(f, "blocksize", 5 * 2**20)
                    data = f.read(block)
                    f2.write(data)
        else:
            self.fs.get_file(path, fn)
        # recurse: the file is now guaranteed to be in the cache
        return self._open(path, mode)
942
+
943
+
944
class LocalTempFile:
    """A temporary local file, which will be uploaded on commit

    Wraps a real local file handle (``self.fh``); any attribute not defined
    here is delegated to that handle, so the object behaves like a normal
    file. On ``close`` (with ``autocommit=True``) or explicit ``commit``,
    the local file is uploaded to ``path`` on the target filesystem.
    """

    def __init__(self, fs, path, fn, mode="wb", autocommit=True, seek=0, **kwargs):
        # fs: target filesystem to upload to on commit
        # path: destination path on the target filesystem
        # fn: local (cache) filename backing this file
        # seek: initial position; used when unpickling a partially written file
        # kwargs: passed through to fs.put on commit
        self.fn = fn
        self.fh = open(fn, mode)
        self.mode = mode
        if seek:
            self.fh.seek(seek)
        self.path = path
        self.size = None
        self.fs = fs
        self.closed = False
        self.autocommit = autocommit
        self.kwargs = kwargs

    def __reduce__(self):
        # always open in r+b to allow continuing writing at a location
        return (
            LocalTempFile,
            (self.fs, self.path, self.fn, "r+b", self.autocommit, self.tell()),
        )

    def __enter__(self):
        return self.fh

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the local file; upload it if autocommit is enabled."""
        # self.size = self.fh.tell()
        if self.closed:
            return
        self.fh.close()
        self.closed = True
        if self.autocommit:
            self.commit()

    def discard(self):
        """Abandon the file: close and delete the local copy without uploading."""
        if self.closed:
            return
        self.fh.close()
        # Fix: mark the file closed so a subsequent close() does not attempt
        # to autocommit (upload) a file whose local copy has been removed.
        self.closed = True
        os.remove(self.fn)

    def commit(self):
        """Upload the local file to ``path`` on the target filesystem."""
        self.fs.put(self.fn, self.path, **self.kwargs)
        # we do not delete the local copy, it's still in the cache.

    @property
    def name(self):
        # expose the local filename, like a regular file object
        return self.fn

    def __repr__(self) -> str:
        return f"LocalTempFile: {self.path}"

    def __getattr__(self, item):
        # delegate everything else (write, tell, seek, ...) to the real handle
        return getattr(self.fh, item)
pythonProject/.venv/Lib/site-packages/fsspec/implementations/dask.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dask
2
+ from distributed.client import Client, _get_global_client
3
+ from distributed.worker import Worker
4
+
5
+ from fsspec import filesystem
6
+ from fsspec.spec import AbstractBufferedFile, AbstractFileSystem
7
+ from fsspec.utils import infer_storage_options
8
+
9
+
10
def _get_client(client):
    """Resolve *client* to a ``distributed.Client``.

    Accepts an existing Client (returned as-is), None (use the current
    global client), or anything else (treated as a connection address).
    """
    if isinstance(client, Client):
        return client
    if client is None:
        return _get_global_client()
    # e.g., connection string
    return Client(client)
18
+
19
+
20
def _in_worker():
    """Return True when this process hosts at least one dask Worker."""
    return len(Worker._instances) > 0
22
+
23
+
24
class DaskWorkerFileSystem(AbstractFileSystem):
    """View files accessible to a worker as any other remote file-system

    When instances are run on the worker, uses the real filesystem. When
    run on the client, they call the worker to provide information or data.

    **Warning** this implementation is experimental, and read-only for now.
    """

    def __init__(
        self, target_protocol=None, target_options=None, fs=None, client=None, **kwargs
    ):
        # Parameters
        # ----------
        # target_protocol: str
        #     protocol of the filesystem to instantiate on the worker
        # target_options: dict
        #     kwargs for that filesystem
        # fs: AbstractFileSystem
        #     an already-instantiated filesystem (alternative to the above)
        # client: distributed.Client, str or None
        #     cluster client or scheduler address; None picks up the global one
        super().__init__(**kwargs)
        # exactly one of `fs` / `target_protocol` must be given (XOR)
        if not (fs is None) ^ (target_protocol is None):
            raise ValueError(
                "Please provide one of filesystem instance (fs) or"
                " target_protocol, not both"
            )
        self.target_protocol = target_protocol
        self.target_options = target_options
        # set by _determine_worker: True inside a worker, False on the client
        self.worker = None
        self.client = client
        self.fs = fs
        self._determine_worker()

    @staticmethod
    def _get_kwargs_from_urls(path):
        # Allow URLs carrying host:port to select the scheduler address.
        so = infer_storage_options(path)
        if "host" in so and "port" in so:
            return {"client": f"{so['host']}:{so['port']}"}
        else:
            return {}

    def _determine_worker(self):
        if _in_worker():
            # running inside a worker: use the real target filesystem
            self.worker = True
            if self.fs is None:
                self.fs = filesystem(
                    self.target_protocol, **(self.target_options or {})
                )
        else:
            # running on the client: proxy all calls through the cluster.
            # `rfs` is a delayed version of self; calling a method on it and
            # computing executes that method on a worker.
            self.worker = False
            self.client = _get_client(self.client)
            self.rfs = dask.delayed(self)

    def mkdir(self, *args, **kwargs):
        if self.worker:
            self.fs.mkdir(*args, **kwargs)
        else:
            self.rfs.mkdir(*args, **kwargs).compute()

    def rm(self, *args, **kwargs):
        if self.worker:
            self.fs.rm(*args, **kwargs)
        else:
            self.rfs.rm(*args, **kwargs).compute()

    def copy(self, *args, **kwargs):
        if self.worker:
            self.fs.copy(*args, **kwargs)
        else:
            self.rfs.copy(*args, **kwargs).compute()

    def mv(self, *args, **kwargs):
        if self.worker:
            self.fs.mv(*args, **kwargs)
        else:
            self.rfs.mv(*args, **kwargs).compute()

    def ls(self, *args, **kwargs):
        if self.worker:
            return self.fs.ls(*args, **kwargs)
        else:
            return self.rfs.ls(*args, **kwargs).compute()

    def _open(
        self,
        path,
        mode="rb",
        block_size=None,
        autocommit=True,
        cache_options=None,
        **kwargs,
    ):
        if self.worker:
            # on the worker: open the real file
            return self.fs._open(
                path,
                mode=mode,
                block_size=block_size,
                autocommit=autocommit,
                cache_options=cache_options,
                **kwargs,
            )
        else:
            # on the client: return a proxy file that fetches byte ranges
            # through the cluster on demand
            return DaskFile(
                fs=self,
                path=path,
                mode=mode,
                block_size=block_size,
                autocommit=autocommit,
                cache_options=cache_options,
                **kwargs,
            )

    def fetch_range(self, path, mode, start, end):
        """Read bytes [start, end) of ``path``; executed on a worker when
        called from the client via the delayed proxy."""
        if self.worker:
            with self._open(path, mode) as f:
                f.seek(start)
                return f.read(end - start)
        else:
            return self.rfs.fetch_range(path, mode, start, end).compute()
135
+
136
+
137
class DaskFile(AbstractBufferedFile):
    """Buffered read-only file proxy whose byte ranges are fetched through
    the dask cluster by the owning ``DaskWorkerFileSystem``."""

    def __init__(self, mode="rb", **kwargs):
        # writing through the cluster is not supported; reject early
        if mode != "rb":
            raise ValueError('Remote dask files can only be opened in "rb" mode')
        super().__init__(**kwargs)

    def _initiate_upload(self):
        """Create remote file/upload"""
        # read-only: nothing to initiate

    def _upload_chunk(self, final=False):
        # read-only: nothing to upload
        pass

    def _fetch_range(self, start, end):
        """Get the specified set of bytes from remote"""
        return self.fs.fetch_range(self.path, self.mode, start, end)
pythonProject/.venv/Lib/site-packages/fsspec/implementations/data.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ from typing import Optional
4
+ from urllib.parse import unquote
5
+
6
+ from fsspec import AbstractFileSystem
7
+
8
+
9
class DataFileSystem(AbstractFileSystem):
    """A handy decoder for data-URLs

    Example
    -------
    >>> with fsspec.open("data:,Hello%2C%20World%21") as f:
    ...     print(f.read())
    b"Hello, World!"

    See https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/Data_URLs
    """

    protocol = "data"

    def __init__(self, **kwargs):
        """No parameters for this filesystem"""
        super().__init__(**kwargs)

    def cat_file(self, path, start=None, end=None, **kwargs):
        # split the URL into its "data:<mediatype>[;base64]" header and payload
        header, payload = path.split(",", 1)
        if header.endswith("base64"):
            raw = base64.b64decode(payload)
        else:
            raw = unquote(payload).encode()
        return raw[start:end]

    def info(self, path, **kwargs):
        header, name = path.split(",", 1)
        # mediatype sits between "data:" and the first ";" (if any)
        mime = header.split(":", 1)[1].split(";", 1)[0]
        return {
            "name": name,
            "size": len(self.cat_file(path)),
            "type": "file",
            "mimetype": mime,
        }

    def _open(
        self,
        path,
        mode="rb",
        block_size=None,
        autocommit=True,
        cache_options=None,
        **kwargs,
    ):
        if "r" not in mode:
            raise ValueError("Read only filesystem")
        return io.BytesIO(self.cat_file(path))

    @staticmethod
    def encode(data: bytes, mime: Optional[str] = None):
        """Format the given data into data-URL syntax

        This version always base64 encodes, even when the data is ascii/url-safe.
        """
        prefix = f"data:{mime or ''};base64,"
        return prefix + base64.b64encode(data).decode()
pythonProject/.venv/Lib/site-packages/fsspec/implementations/dbfs.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import urllib
5
+
6
+ import requests
7
+ from requests.adapters import HTTPAdapter, Retry
8
+ from typing_extensions import override
9
+
10
+ from fsspec import AbstractFileSystem
11
+ from fsspec.spec import AbstractBufferedFile
12
+
13
+
14
class DatabricksException(Exception):
    """Exception carrying the structured error payload returned by the
    Databricks DBFS REST API (an error code plus a human-readable message).
    """

    def __init__(self, error_code, message, details=None):
        """Create a new DatabricksException"""
        super().__init__(message)

        self.details = details
        self.message = message
        self.error_code = error_code
26
+
27
+
28
class DatabricksFileSystem(AbstractFileSystem):
    """
    Get access to the Databricks filesystem implementation over HTTP.
    Can be used inside and outside of a databricks cluster.
    """

    def __init__(self, instance, token, **kwargs):
        """
        Create a new DatabricksFileSystem.

        Parameters
        ----------
        instance: str
            The instance URL of the databricks cluster.
            For example for an Azure databricks cluster, this
            has the form adb-<some-number>.<two digits>.azuredatabricks.net.
        token: str
            Your personal token. Find out more
            here: https://docs.databricks.com/dev-tools/api/latest/authentication.html
        """
        self.instance = instance
        self.token = token
        self.session = requests.Session()
        # retry transient HTTP failures with exponential backoff
        self.retries = Retry(
            total=10,
            backoff_factor=0.05,
            status_forcelist=[408, 429, 500, 502, 503, 504],
        )

        self.session.mount("https://", HTTPAdapter(max_retries=self.retries))
        self.session.headers.update({"Authorization": f"Bearer {self.token}"})

        super().__init__(**kwargs)

    @override
    def _ls_from_cache(self, path) -> list[dict[str, str | int]] | None:
        """Check cache for listing

        Returns listing, if found (may be empty list for a directory that
        exists but contains nothing), None if not in cache.
        """
        # a cached listing of `path` itself is deliberately discarded; only
        # the parent's listing is consulted below
        self.dircache.pop(path.rstrip("/"), None)

        parent = self._parent(path)
        if parent in self.dircache:
            for entry in self.dircache[parent]:
                if entry["name"] == path.rstrip("/"):
                    if entry["type"] != "directory":
                        return [entry]
                    # path is a directory but its contents are not cached
                    return []
            # parent is cached and does not contain `path`
            raise FileNotFoundError(path)
        # parent not cached: caller must query the API (implicit None)

    def ls(self, path, detail=True, **kwargs):
        """
        List the contents of the given path.

        Parameters
        ----------
        path: str
            Absolute path
        detail: bool
            Return not only the list of filenames,
            but also additional information on file sizes
            and types.
        """
        try:
            out = self._ls_from_cache(path)
        except FileNotFoundError:
            # This happens if the `path`'s parent was cached, but `path` is not
            # there. This suggests that `path` is new since the parent was
            # cached. Attempt to invalidate parent's cache before continuing.
            self.dircache.pop(self._parent(path), None)
            out = None

        if not out:
            try:
                r = self._send_to_api(
                    method="get", endpoint="list", json={"path": path}
                )
            except DatabricksException as e:
                if e.error_code == "RESOURCE_DOES_NOT_EXIST":
                    raise FileNotFoundError(e.message) from e

                raise
            files = r.get("files", [])
            # translate the API's entries into fsspec's listing format
            out = [
                {
                    "name": o["path"],
                    "type": "directory" if o["is_dir"] else "file",
                    "size": o["file_size"],
                }
                for o in files
            ]
            self.dircache[path] = out

        if detail:
            return out
        return [o["name"] for o in out]

    def makedirs(self, path, exist_ok=True):
        """
        Create a given absolute path and all of its parents.

        Parameters
        ----------
        path: str
            Absolute path to create
        exist_ok: bool
            If false, checks if the folder
            exists before creating it (and raises an
            Exception if this is the case)
        """
        if not exist_ok:
            try:
                # If the following succeeds, the path is already present
                self._send_to_api(
                    method="get", endpoint="get-status", json={"path": path}
                )
                raise FileExistsError(f"Path {path} already exists")
            except DatabricksException as e:
                # NOTE(review): error codes other than RESOURCE_DOES_NOT_EXIST
                # are silently swallowed here — confirm this is intended.
                if e.error_code == "RESOURCE_DOES_NOT_EXIST":
                    pass

        try:
            self._send_to_api(method="post", endpoint="mkdirs", json={"path": path})
        except DatabricksException as e:
            if e.error_code == "RESOURCE_ALREADY_EXISTS":
                raise FileExistsError(e.message) from e

            raise
        self.invalidate_cache(self._parent(path))

    def mkdir(self, path, create_parents=True, **kwargs):
        """
        Create a given absolute path and all of its parents.

        Parameters
        ----------
        path: str
            Absolute path to create
        create_parents: bool
            Whether to create all parents or not.
            "False" is not implemented so far.
        """
        if not create_parents:
            raise NotImplementedError

        self.mkdirs(path, **kwargs)

    def rm(self, path, recursive=False, **kwargs):
        """
        Remove the file or folder at the given absolute path.

        Parameters
        ----------
        path: str
            Absolute path what to remove
        recursive: bool
            Recursively delete all files in a folder.
        """
        try:
            self._send_to_api(
                method="post",
                endpoint="delete",
                json={"path": path, "recursive": recursive},
            )
        except DatabricksException as e:
            # This is not really an exception, it just means
            # not everything was deleted so far
            if e.error_code == "PARTIAL_DELETE":
                self.rm(path=path, recursive=recursive)
            elif e.error_code == "IO_ERROR":
                # Using the same exception as the os module would use here
                raise OSError(e.message) from e

            # NOTE(review): this bare `raise` also re-raises after a
            # successful PARTIAL_DELETE retry above — confirm intended.
            raise
        self.invalidate_cache(self._parent(path))

    def mv(
        self, source_path, destination_path, recursive=False, maxdepth=None, **kwargs
    ):
        """
        Move a source to a destination path.

        A note from the original [databricks API manual]
        (https://docs.databricks.com/dev-tools/api/latest/dbfs.html#move).

        When moving a large number of files the API call will time out after
        approximately 60s, potentially resulting in partially moved data.
        Therefore, for operations that move more than 10k files, we strongly
        discourage using the DBFS REST API.

        Parameters
        ----------
        source_path: str
            From where to move (absolute path)
        destination_path: str
            To where to move (absolute path)
        recursive: bool
            Not implemented to far.
        maxdepth:
            Not implemented to far.
        """
        if recursive:
            raise NotImplementedError
        if maxdepth:
            raise NotImplementedError

        try:
            self._send_to_api(
                method="post",
                endpoint="move",
                json={"source_path": source_path, "destination_path": destination_path},
            )
        except DatabricksException as e:
            if e.error_code == "RESOURCE_DOES_NOT_EXIST":
                raise FileNotFoundError(e.message) from e
            elif e.error_code == "RESOURCE_ALREADY_EXISTS":
                raise FileExistsError(e.message) from e

            raise
        self.invalidate_cache(self._parent(source_path))
        self.invalidate_cache(self._parent(destination_path))

    def _open(self, path, mode="rb", block_size="default", **kwargs):
        """
        Overwrite the base class method to make sure to create a DBFile.
        All arguments are copied from the base method.

        Only the default blocksize is allowed.
        """
        return DatabricksFile(self, path, mode=mode, block_size=block_size, **kwargs)

    def _send_to_api(self, method, endpoint, json):
        """
        Send the given json to the DBFS API
        using a get or post request (specified by the argument `method`).

        Parameters
        ----------
        method: str
            Which http method to use for communication; "get" or "post".
        endpoint: str
            Where to send the request to (last part of the API URL)
        json: dict
            Dictionary of information to send
        """
        if method == "post":
            session_call = self.session.post
        elif method == "get":
            session_call = self.session.get
        else:
            raise ValueError(f"Do not understand method {method}")

        url = urllib.parse.urljoin(f"https://{self.instance}/api/2.0/dbfs/", endpoint)

        r = session_call(url, json=json)

        # The DBFS API will return a json, also in case of an exception.
        # We want to preserve this information as good as possible.
        try:
            r.raise_for_status()
        except requests.HTTPError as e:
            # try to extract json error message
            # if that fails, fall back to the original exception
            try:
                exception_json = e.response.json()
            except Exception:
                raise e from None

            raise DatabricksException(**exception_json) from e

        return r.json()

    def _create_handle(self, path, overwrite=True):
        """
        Internal function to create a handle, which can be used to
        write blocks of a file to DBFS.
        A handle has a unique identifier which needs to be passed
        whenever written during this transaction.
        The handle is active for 10 minutes - after that a new
        write transaction needs to be created.
        Make sure to close the handle after you are finished.

        Parameters
        ----------
        path: str
            Absolute path for this file.
        overwrite: bool
            If a file already exist at this location, either overwrite
            it or raise an exception.
        """
        try:
            r = self._send_to_api(
                method="post",
                endpoint="create",
                json={"path": path, "overwrite": overwrite},
            )
            return r["handle"]
        except DatabricksException as e:
            if e.error_code == "RESOURCE_ALREADY_EXISTS":
                raise FileExistsError(e.message) from e

            raise

    def _close_handle(self, handle):
        """
        Close a handle, which was opened by :func:`_create_handle`.

        Parameters
        ----------
        handle: str
            Which handle to close.
        """
        try:
            self._send_to_api(method="post", endpoint="close", json={"handle": handle})
        except DatabricksException as e:
            if e.error_code == "RESOURCE_DOES_NOT_EXIST":
                raise FileNotFoundError(e.message) from e

            raise

    def _add_data(self, handle, data):
        """
        Upload data to an already opened file handle
        (opened by :func:`_create_handle`).
        The maximal allowed data size is 1MB after
        conversion to base64.
        Remember to close the handle when you are finished.

        Parameters
        ----------
        handle: str
            Which handle to upload data to.
        data: bytes
            Block of data to add to the handle.
        """
        # the API transports block contents as base64 text
        data = base64.b64encode(data).decode()
        try:
            self._send_to_api(
                method="post",
                endpoint="add-block",
                json={"handle": handle, "data": data},
            )
        except DatabricksException as e:
            if e.error_code == "RESOURCE_DOES_NOT_EXIST":
                raise FileNotFoundError(e.message) from e
            elif e.error_code == "MAX_BLOCK_SIZE_EXCEEDED":
                raise ValueError(e.message) from e

            raise

    def _get_data(self, path, start, end):
        """
        Download data in bytes from a given absolute path in a block
        from [start, start+length].
        The maximum number of allowed bytes to read is 1MB.

        Parameters
        ----------
        path: str
            Absolute path to download data from
        start: int
            Start position of the block
        end: int
            End position of the block
        """
        try:
            r = self._send_to_api(
                method="get",
                endpoint="read",
                json={"path": path, "offset": start, "length": end - start},
            )
            return base64.b64decode(r["data"])
        except DatabricksException as e:
            if e.error_code == "RESOURCE_DOES_NOT_EXIST":
                raise FileNotFoundError(e.message) from e
            elif e.error_code in ["INVALID_PARAMETER_VALUE", "MAX_READ_SIZE_EXCEEDED"]:
                raise ValueError(e.message) from e

            raise

    def invalidate_cache(self, path=None):
        # drop the directory-listing cache for `path` (or everything)
        if path is None:
            self.dircache.clear()
        else:
            self.dircache.pop(path, None)
        super().invalidate_cache(path)
416
+
417
+
418
class DatabricksFile(AbstractBufferedFile):
    """
    Helper class for files referenced in the DatabricksFileSystem.

    Data travels through the DBFS REST API in base64-encoded blocks of at
    most ``DEFAULT_BLOCK_SIZE`` bytes, which is why only that block size
    is accepted.
    """

    DEFAULT_BLOCK_SIZE = 1 * 2**20  # only allowed block size

    def __init__(
        self,
        fs,
        path,
        mode="rb",
        block_size="default",
        autocommit=True,
        cache_type="readahead",
        cache_options=None,
        **kwargs,
    ):
        """
        Create a new instance of the DatabricksFile.

        The blocksize needs to be the default one.

        Raises
        ------
        ValueError
            If a block size other than the default is requested.
        """
        if block_size is None or block_size == "default":
            block_size = self.DEFAULT_BLOCK_SIZE

        # Raise instead of `assert`: assertions are stripped under `python -O`,
        # and a wrong block size is a caller error, not an internal invariant.
        if block_size != self.DEFAULT_BLOCK_SIZE:
            raise ValueError(
                f"Only the default block size is allowed, not {block_size}"
            )

        super().__init__(
            fs,
            path,
            mode=mode,
            block_size=block_size,
            autocommit=autocommit,
            cache_type=cache_type,
            cache_options=cache_options or {},
            **kwargs,
        )

    def _initiate_upload(self):
        """Internal function to start a file upload"""
        self.handle = self.fs._create_handle(self.path)

    def _upload_chunk(self, final=False):
        """Internal function to add a chunk of data to a started upload"""
        self.buffer.seek(0)
        data = self.buffer.getvalue()

        # The API limits each add-block call; split into block-sized pieces.
        data_chunks = [
            data[start:end] for start, end in self._to_sized_blocks(len(data))
        ]

        for data_chunk in data_chunks:
            self.fs._add_data(handle=self.handle, data=data_chunk)

        if final:
            self.fs._close_handle(handle=self.handle)
        return True

    def _fetch_range(self, start, end):
        """Internal function to download a block of data"""
        return_buffer = b""
        length = end - start
        # read in block-sized pieces, since the API caps each read at 1MB
        for chunk_start, chunk_end in self._to_sized_blocks(length, start):
            return_buffer += self.fs._get_data(
                path=self.path, start=chunk_start, end=chunk_end
            )

        return return_buffer

    def _to_sized_blocks(self, length, start=0):
        """Helper function to split a range from 0 to total_length into blocksizes"""
        end = start + length
        for data_chunk in range(start, end, self.blocksize):
            data_start = data_chunk
            data_end = min(end, data_chunk + self.blocksize)
            yield data_start, data_end
pythonProject/.venv/Lib/site-packages/fsspec/implementations/dirfs.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .. import filesystem
2
+ from ..asyn import AsyncFileSystem
3
+
4
+
5
class DirFileSystem(AsyncFileSystem):
    """Directory prefix filesystem

    The DirFileSystem is a filesystem-wrapper. It assumes every path it is dealing with
    is relative to the `path`. After performing the necessary paths operation it
    delegates everything to the wrapped filesystem.
    """

    protocol = "dir"

    def __init__(
        self,
        path=None,
        fs=None,
        fo=None,
        target_protocol=None,
        target_options=None,
        **storage_options,
    ):
        """
        Parameters
        ----------
        path: str
            Path to the directory.
        fs: AbstractFileSystem
            An instantiated filesystem to wrap.
        target_protocol, target_options:
            if fs is none, construct it from these
        fo: str
            Alternate for path; do not provide both
        """
        super().__init__(**storage_options)
        if fs is None:
            fs = filesystem(protocol=target_protocol, **(target_options or {}))
        path = path or fo

        # sync/async modes of the wrapper and the wrapped fs must agree
        if self.asynchronous and not fs.async_impl:
            raise ValueError("can't use asynchronous with non-async fs")

        if fs.async_impl and self.asynchronous != fs.asynchronous:
            raise ValueError("both dirfs and fs should be in the same sync/async mode")

        self.path = fs._strip_protocol(path)
        self.fs = fs

    def _join(self, path):
        # Prepend the directory prefix to a path; dicts and lists of paths
        # are handled element by element (recursively via self._join).
        if isinstance(path, str):
            if not self.path:
                return path
            if not path:
                return self.path
            return self.fs.sep.join((self.path, self._strip_protocol(path)))
        if isinstance(path, dict):
            return {self._join(_path): value for _path, value in path.items()}
        return [self._join(_path) for _path in path]

    def _relpath(self, path):
        # Inverse of _join: strip the directory prefix from a path (or from
        # every element of a list) returned by the wrapped filesystem.
        if isinstance(path, str):
            if not self.path:
                return path
            # We need to account for S3FileSystem returning paths that do not
            # start with a '/'
            if path == self.path or (
                self.path.startswith(self.fs.sep) and path == self.path[1:]
            ):
                return ""
            prefix = self.path + self.fs.sep
            if self.path.startswith(self.fs.sep) and not path.startswith(self.fs.sep):
                prefix = prefix[1:]
            assert path.startswith(prefix)
            return path[len(prefix) :]
        return [self._relpath(_path) for _path in path]

    # Wrappers below: each method joins the prefix onto incoming paths,
    # delegates to the wrapped filesystem, and (where paths are returned)
    # translates results back with _relpath. Async `_foo` and sync `foo`
    # variants mirror each other.

    @property
    def sep(self):
        return self.fs.sep

    async def set_session(self, *args, **kwargs):
        return await self.fs.set_session(*args, **kwargs)

    async def _rm_file(self, path, **kwargs):
        return await self.fs._rm_file(self._join(path), **kwargs)

    def rm_file(self, path, **kwargs):
        return self.fs.rm_file(self._join(path), **kwargs)

    async def _rm(self, path, *args, **kwargs):
        return await self.fs._rm(self._join(path), *args, **kwargs)

    def rm(self, path, *args, **kwargs):
        return self.fs.rm(self._join(path), *args, **kwargs)

    async def _cp_file(self, path1, path2, **kwargs):
        return await self.fs._cp_file(self._join(path1), self._join(path2), **kwargs)

    def cp_file(self, path1, path2, **kwargs):
        return self.fs.cp_file(self._join(path1), self._join(path2), **kwargs)

    async def _copy(
        self,
        path1,
        path2,
        *args,
        **kwargs,
    ):
        return await self.fs._copy(
            self._join(path1),
            self._join(path2),
            *args,
            **kwargs,
        )

    def copy(self, path1, path2, *args, **kwargs):
        return self.fs.copy(
            self._join(path1),
            self._join(path2),
            *args,
            **kwargs,
        )

    async def _pipe(self, path, *args, **kwargs):
        return await self.fs._pipe(self._join(path), *args, **kwargs)

    def pipe(self, path, *args, **kwargs):
        return self.fs.pipe(self._join(path), *args, **kwargs)

    async def _pipe_file(self, path, *args, **kwargs):
        return await self.fs._pipe_file(self._join(path), *args, **kwargs)

    def pipe_file(self, path, *args, **kwargs):
        return self.fs.pipe_file(self._join(path), *args, **kwargs)

    async def _cat_file(self, path, *args, **kwargs):
        return await self.fs._cat_file(self._join(path), *args, **kwargs)

    def cat_file(self, path, *args, **kwargs):
        return self.fs.cat_file(self._join(path), *args, **kwargs)

    async def _cat(self, path, *args, **kwargs):
        ret = await self.fs._cat(
            self._join(path),
            *args,
            **kwargs,
        )

        # cat may return a dict keyed by path when given multiple inputs;
        # translate those keys back to prefix-relative form
        if isinstance(ret, dict):
            return {self._relpath(key): value for key, value in ret.items()}

        return ret

    def cat(self, path, *args, **kwargs):
        ret = self.fs.cat(
            self._join(path),
            *args,
            **kwargs,
        )

        if isinstance(ret, dict):
            return {self._relpath(key): value for key, value in ret.items()}

        return ret

    async def _put_file(self, lpath, rpath, **kwargs):
        # only the remote path carries the prefix; lpath is local
        return await self.fs._put_file(lpath, self._join(rpath), **kwargs)

    def put_file(self, lpath, rpath, **kwargs):
        return self.fs.put_file(lpath, self._join(rpath), **kwargs)

    async def _put(
        self,
        lpath,
        rpath,
        *args,
        **kwargs,
    ):
        return await self.fs._put(
            lpath,
            self._join(rpath),
            *args,
            **kwargs,
        )

    def put(self, lpath, rpath, *args, **kwargs):
        return self.fs.put(
            lpath,
            self._join(rpath),
            *args,
            **kwargs,
        )

    async def _get_file(self, rpath, lpath, **kwargs):
        return await self.fs._get_file(self._join(rpath), lpath, **kwargs)

    def get_file(self, rpath, lpath, **kwargs):
        return self.fs.get_file(self._join(rpath), lpath, **kwargs)

    async def _get(self, rpath, *args, **kwargs):
        return await self.fs._get(self._join(rpath), *args, **kwargs)

    def get(self, rpath, *args, **kwargs):
        return self.fs.get(self._join(rpath), *args, **kwargs)

    async def _isfile(self, path):
        return await self.fs._isfile(self._join(path))

    def isfile(self, path):
        return self.fs.isfile(self._join(path))

    async def _isdir(self, path):
        return await self.fs._isdir(self._join(path))

    def isdir(self, path):
        return self.fs.isdir(self._join(path))

    async def _size(self, path):
        return await self.fs._size(self._join(path))

    def size(self, path):
        return self.fs.size(self._join(path))

    async def _exists(self, path):
        return await self.fs._exists(self._join(path))

    def exists(self, path):
        return self.fs.exists(self._join(path))

    async def _info(self, path, **kwargs):
        info = await self.fs._info(self._join(path), **kwargs)
        # copy before mutating so the wrapped fs's (possibly cached) dict
        # is not altered
        info = info.copy()
        info["name"] = self._relpath(info["name"])
        return info

    def info(self, path, **kwargs):
        info = self.fs.info(self._join(path), **kwargs)
        info = info.copy()
        info["name"] = self._relpath(info["name"])
        return info

    async def _ls(self, path, detail=True, **kwargs):
        ret = (await self.fs._ls(self._join(path), detail=detail, **kwargs)).copy()
        if detail:
            out = []
            for entry in ret:
                entry = entry.copy()
                entry["name"] = self._relpath(entry["name"])
                out.append(entry)
            return out

        return self._relpath(ret)

    def ls(self, path, detail=True, **kwargs):
        ret = self.fs.ls(self._join(path), detail=detail, **kwargs).copy()
        if detail:
            out = []
            for entry in ret:
                entry = entry.copy()
                entry["name"] = self._relpath(entry["name"])
                out.append(entry)
            return out

        return self._relpath(ret)

    async def _walk(self, path, *args, **kwargs):
        async for root, dirs, files in self.fs._walk(self._join(path), *args, **kwargs):
            yield self._relpath(root), dirs, files

    def walk(self, path, *args, **kwargs):
        for root, dirs, files in self.fs.walk(self._join(path), *args, **kwargs):
            yield self._relpath(root), dirs, files

    async def _glob(self, path, **kwargs):
        detail = kwargs.get("detail", False)
        ret = await self.fs._glob(self._join(path), **kwargs)
        if detail:
            return {self._relpath(path): info for path, info in ret.items()}
        return self._relpath(ret)

    def glob(self, path, **kwargs):
        detail = kwargs.get("detail", False)
        ret = self.fs.glob(self._join(path), **kwargs)
        if detail:
            return {self._relpath(path): info for path, info in ret.items()}
        return self._relpath(ret)

    async def _du(self, path, *args, **kwargs):
        total = kwargs.get("total", True)
        ret = await self.fs._du(self._join(path), *args, **kwargs)
        if total:
            # a single aggregate number, no path translation needed
            return ret

        return {self._relpath(path): size for path, size in ret.items()}

    def du(self, path, *args, **kwargs):
        total = kwargs.get("total", True)
        ret = self.fs.du(self._join(path), *args, **kwargs)
        if total:
            return ret

        return {self._relpath(path): size for path, size in ret.items()}

    async def _find(self, path, *args, **kwargs):
        detail = kwargs.get("detail", False)
        ret = await self.fs._find(self._join(path), *args, **kwargs)
        if detail:
            return {self._relpath(path): info for path, info in ret.items()}
        return self._relpath(ret)

    def find(self, path, *args, **kwargs):
        detail = kwargs.get("detail", False)
        ret = self.fs.find(self._join(path), *args, **kwargs)
        if detail:
            return {self._relpath(path): info for path, info in ret.items()}
        return self._relpath(ret)

    async def _expand_path(self, path, *args, **kwargs):
        return self._relpath(
            await self.fs._expand_path(self._join(path), *args, **kwargs)
        )

    def expand_path(self, path, *args, **kwargs):
        return self._relpath(self.fs.expand_path(self._join(path), *args, **kwargs))

    async def _mkdir(self, path, *args, **kwargs):
        return await self.fs._mkdir(self._join(path), *args, **kwargs)

    def mkdir(self, path, *args, **kwargs):
        return self.fs.mkdir(self._join(path), *args, **kwargs)

    async def _makedirs(self, path, *args, **kwargs):
        return await self.fs._makedirs(self._join(path), *args, **kwargs)

    def makedirs(self, path, *args, **kwargs):
        return self.fs.makedirs(self._join(path), *args, **kwargs)

    def rmdir(self, path):
        return self.fs.rmdir(self._join(path))

    def mv(self, path1, path2, **kwargs):
        return self.fs.mv(
            self._join(path1),
            self._join(path2),
            **kwargs,
        )

    def touch(self, path, **kwargs):
        return self.fs.touch(self._join(path), **kwargs)

    def created(self, path):
        return self.fs.created(self._join(path))

    def modified(self, path):
        return self.fs.modified(self._join(path))

    def sign(self, path, *args, **kwargs):
        return self.fs.sign(self._join(path), *args, **kwargs)

    def __repr__(self):
        return f"{self.__class__.__qualname__}(path='{self.path}', fs={self.fs})"

    def open(
        self,
        path,
        *args,
        **kwargs,
    ):
        return self.fs.open(
            self._join(path),
            *args,
            **kwargs,
        )

    async def open_async(
        self,
        path,
        *args,
        **kwargs,
    ):
        return await self.fs.open_async(
            self._join(path),
            *args,
            **kwargs,
        )
+ )
pythonProject/.venv/Lib/site-packages/fsspec/utils.py ADDED
@@ -0,0 +1,737 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import logging
5
+ import math
6
+ import os
7
+ import re
8
+ import sys
9
+ import tempfile
10
+ from collections.abc import Iterable, Iterator, Sequence
11
+ from functools import partial
12
+ from hashlib import md5
13
+ from importlib.metadata import version
14
+ from typing import (
15
+ IO,
16
+ TYPE_CHECKING,
17
+ Any,
18
+ Callable,
19
+ TypeVar,
20
+ )
21
+ from urllib.parse import urlsplit
22
+
23
+ if TYPE_CHECKING:
24
+ import pathlib
25
+
26
+ from typing_extensions import TypeGuard
27
+
28
+ from fsspec.spec import AbstractFileSystem
29
+
30
+
31
+ DEFAULT_BLOCK_SIZE = 5 * 2**20
32
+
33
+ T = TypeVar("T")
34
+
35
+
36
def infer_storage_options(
    urlpath: str, inherit_storage_options: dict[str, Any] | None = None
) -> dict[str, Any]:
    """Infer storage options from URL path and merge it with existing storage
    options.

    Parameters
    ----------
    urlpath: str or unicode
        Either local absolute file path or URL (hdfs://namenode:8020/file.csv)
    inherit_storage_options: dict (optional)
        Its contents will get merged with the inferred information from the
        given path

    Returns
    -------
    Storage options dict.

    Raises
    ------
    KeyError
        If an inherited option conflicts with an inferred one (via
        ``update_storage_options``).

    Examples
    --------
    >>> infer_storage_options('/mnt/datasets/test.csv')  # doctest: +SKIP
    {"protocol": "file", "path": "/mnt/datasets/test.csv"}
    >>> infer_storage_options(
    ...     'hdfs://username:pwd@node:123/mnt/datasets/test.csv?q=1',
    ...     inherit_storage_options={'extra': 'value'},
    ... )  # doctest: +SKIP
    {"protocol": "hdfs", "username": "username", "password": "pwd",
    "host": "node", "port": 123, "path": "/mnt/datasets/test.csv",
    "url_query": "q=1", "extra": "value"}
    """
    # Handle Windows paths including disk name in this special case
    if (
        re.match(r"^[a-zA-Z]:[\\/]", urlpath)
        or re.match(r"^[a-zA-Z0-9]+://", urlpath) is None
    ):
        return {"protocol": "file", "path": urlpath}

    parsed_path = urlsplit(urlpath)
    protocol = parsed_path.scheme or "file"
    # keep any fragment as part of the path, since urlsplit separated it
    if parsed_path.fragment:
        path = "#".join([parsed_path.path, parsed_path.fragment])
    else:
        path = parsed_path.path
    if protocol == "file":
        # Special case parsing file protocol URL on Windows according to:
        # https://msdn.microsoft.com/en-us/library/jj710207.aspx
        windows_path = re.match(r"^/([a-zA-Z])[:|]([\\/].*)$", path)
        if windows_path:
            drive, path = windows_path.groups()
            path = f"{drive}:{path}"

    if protocol in ["http", "https"]:
        # for HTTP, we don't want to parse, as requests will anyway
        return {"protocol": protocol, "path": urlpath}

    options: dict[str, Any] = {"protocol": protocol, "path": path}

    if parsed_path.netloc:
        # Parse `hostname` from netloc manually because `parsed_path.hostname`
        # lowercases the hostname which is not always desirable (e.g. in S3):
        # https://github.com/dask/dask/issues/1417
        options["host"] = parsed_path.netloc.rsplit("@", 1)[-1].rsplit(":", 1)[0]

        if protocol in ("s3", "s3a", "gcs", "gs"):
            # bucket-style protocols fold the netloc (bucket) into the path
            options["path"] = options["host"] + options["path"]
        if parsed_path.port:
            options["port"] = parsed_path.port
        if parsed_path.username:
            options["username"] = parsed_path.username
        if parsed_path.password:
            options["password"] = parsed_path.password

    if parsed_path.query:
        options["url_query"] = parsed_path.query
    if parsed_path.fragment:
        options["url_fragment"] = parsed_path.fragment

    if inherit_storage_options:
        update_storage_options(options, inherit_storage_options)

    return options
119
+
120
+
121
def update_storage_options(
    options: dict[str, Any], inherited: dict[str, Any] | None = None
) -> None:
    """Merge *inherited* options into *options* in place.

    Raises ``KeyError`` if the two dicts disagree on the value of any
    shared key; keys with equal values are allowed to overlap.
    """
    inherited = inherited or {}
    for key in set(options) & set(inherited):
        if options.get(key) != inherited.get(key):
            raise KeyError(
                f"Collision between inferred and specified storage "
                f"option:\n{key}"
            )
    options.update(inherited)
135
+
136
+
137
+ # Compression extensions registered via fsspec.compression.register_compression
138
+ compressions: dict[str, str] = {}
139
+
140
+
141
+ def infer_compression(filename: str) -> str | None:
142
+ """Infer compression, if available, from filename.
143
+
144
+ Infer a named compression type, if registered and available, from filename
145
+ extension. This includes builtin (gz, bz2, zip) compressions, as well as
146
+ optional compressions. See fsspec.compression.register_compression.
147
+ """
148
+ extension = os.path.splitext(filename)[-1].strip(".").lower()
149
+ if extension in compressions:
150
+ return compressions[extension]
151
+ return None
152
+
153
+
154
def build_name_function(max_int: float) -> Callable[[int], str]:
    """Return a function mapping an integer to its zero-padded string form,
    padded wide enough for the largest expected integer.

    >>> name_f = build_name_function(57)

    >>> name_f(7)
    '07'
    >>> name_f(31)
    '31'
    >>> build_name_function(1000)(42)
    '0042'
    >>> build_name_function(999)(42)
    '042'
    >>> build_name_function(0)(0)
    '0'
    """
    # handle corner cases max_int is 0 or exact power of 10
    max_int += 1e-8

    width = int(math.ceil(math.log10(max_int)))

    def name_function(i: int) -> str:
        return str(i).zfill(width)

    return name_function
181
+
182
+
183
def seek_delimiter(file: IO[bytes], delimiter: bytes, blocksize: int) -> bool:
    r"""Seek current file to file start, file end, or byte after delimiter seq.

    Seeks file to next chunk delimiter, where chunks are defined on file start,
    a delimiting sequence, and file end. Use file.tell() to see location afterwards.
    Note that file start is a valid split, so must be at offset > 0 to seek for
    delimiter.

    Parameters
    ----------
    file: a file
    delimiter: bytes
        a delimiter like ``b'\n'`` or message sentinel, matching file .read() type
    blocksize: int
        Number of bytes to read from the file at once.


    Returns
    -------
    Returns True if a delimiter was found, False if at file start or end.

    """

    if file.tell() == 0:
        # beginning-of-file, return without seek
        return False

    # Interface is for binary IO, with delimiter as bytes, but initialize tail
    # with result of file.read to preserve compatibility with text IO.
    tail: bytes | None = None
    while True:
        block = file.read(blocksize)
        if not block:
            # end-of-file without delimiter
            return False
        # carry over the previous block's tail so a delimiter spanning a
        # block boundary is still found
        window = tail + block if tail else block
        try:
            if delimiter in window:
                pos = window.index(delimiter)
                file.seek(file.tell() - (len(window) - pos) + len(delimiter))
                return True
            if len(block) < blocksize:
                # short read: end-of-file without delimiter
                return False
        except (OSError, ValueError):
            pass
        tail = window[-len(delimiter) :]
230
+
231
+
232
def read_block(
    f: IO[bytes],
    offset: int,
    length: int | None,
    delimiter: bytes | None = None,
    split_before: bool = False,
) -> bytes:
    """Read a block of bytes from a file

    Parameters
    ----------
    f: File
        Open file
    offset: int
        Byte offset to start read
    length: int
        Number of bytes to read, read through end of file if None
    delimiter: bytes (optional)
        Ensure reading starts and stops at delimiter bytestring
    split_before: bool (optional)
        Start/stop read *before* delimiter bytestring.


    If using the ``delimiter=`` keyword argument we ensure that the read
    starts and stops at delimiter boundaries that follow the locations
    ``offset`` and ``offset + length``. If ``offset`` is zero then we
    start at zero, regardless of delimiter. The bytestring returned WILL
    include the terminating delimiter string.

    Examples
    --------

    >>> from io import BytesIO  # doctest: +SKIP
    >>> f = BytesIO(b'Alice, 100\\nBob, 200\\nCharlie, 300')  # doctest: +SKIP
    >>> read_block(f, 0, 13)  # doctest: +SKIP
    b'Alice, 100\\nBo'

    >>> read_block(f, 0, 13, delimiter=b'\\n')  # doctest: +SKIP
    b'Alice, 100\\nBob, 200\\n'

    >>> read_block(f, 10, 10, delimiter=b'\\n')  # doctest: +SKIP
    b'Bob, 200\\nCharlie, 300'
    """
    if delimiter:
        # Snap the start of the read to the next delimiter after `offset`
        f.seek(offset)
        found_start_delim = seek_delimiter(f, delimiter, 2**16)
        if length is None:
            return f.read()
        start = f.tell()
        length -= start - offset

        # Snap the end of the read to the next delimiter after offset+length
        f.seek(start + length)
        found_end_delim = seek_delimiter(f, delimiter, 2**16)
        end = f.tell()

        # Adjust split location to before delimiter if seek found the
        # delimiter sequence, not start or end of file.
        if split_before and found_start_delim:
            start -= len(delimiter)

        if split_before and found_end_delim:
            end -= len(delimiter)

        offset = start
        length = end - start

    f.seek(offset)

    # TODO: allow length to be None and read to the end of the file?
    assert length is not None
    return f.read(length)
304
+
305
+
306
def tokenize(*args: Any, **kwargs: Any) -> str:
    """Deterministic token

    (modified from dask.base)

    >>> tokenize([1, 2, '3'])
    '9d71491b50023b06fc76928e6eddb952'

    >>> tokenize('Hello') == tokenize('Hello')
    True
    """
    if kwargs:
        args += (kwargs,)
    payload = str(args).encode()
    try:
        digest = md5(payload)
    except ValueError:
        # FIPS systems: https://github.com/fsspec/filesystem_spec/issues/380
        digest = md5(payload, usedforsecurity=False)
    return digest.hexdigest()
325
+
326
+
327
def stringify_path(filepath: str | os.PathLike[str] | pathlib.Path) -> str:
    """Attempt to convert a path-like object to a string.

    Parameters
    ----------
    filepath: object to be converted

    Returns
    -------
    filepath_str: maybe a string version of the object

    Notes
    -----
    Objects supporting the fspath protocol are coerced according to its
    __fspath__ method.

    For backwards compatibility with older Python version, pathlib.Path
    objects are specially coerced.

    Any other object is passed through unchanged, which includes bytes,
    strings, buffers, or anything else that's not even path-like.
    """
    if isinstance(filepath, str):
        return filepath
    if hasattr(filepath, "__fspath__"):
        return filepath.__fspath__()
    if hasattr(filepath, "path"):
        # e.g. OpenFile-like objects expose a `.path` attribute
        return filepath.path
    return filepath  # type: ignore[return-value]
357
+
358
+
359
def make_instance(
    cls: Callable[..., T], args: Sequence[Any], kwargs: dict[str, Any]
) -> T:
    """Instantiate *cls* with the given arguments and run its
    ``_determine_worker`` hook (used when reconstructing filesystems,
    e.g. after unpickling on a worker)."""
    instance = cls(*args, **kwargs)
    instance._determine_worker()  # type: ignore[attr-defined]
    return instance
365
+
366
+
367
def common_prefix(paths: Iterable[str]) -> str:
    """For a list of paths, find the shortest prefix common to all"""
    parts = [p.split("/") for p in paths]
    depth = min(len(p) for p in parts)
    # count how many leading components every path shares
    matched = 0
    while matched < depth and all(p[matched] == parts[0][matched] for p in parts):
        matched += 1
    return "/".join(parts[0][:matched])
378
+
379
+
380
def other_paths(
    paths: list[str],
    path2: str | list[str],
    exists: bool = False,
    flatten: bool = False,
) -> list[str]:
    """In bulk file operations, construct a new file tree from a list of files

    Parameters
    ----------
    paths: list of str
        The input file tree
    path2: str or list of str
        Root to construct the new list in. If this is already a list of str, we just
        assert it has the right number of elements.
    exists: bool (optional)
        For a str destination, it is already exists (and is a dir), files should
        end up inside.
    flatten: bool (optional)
        Whether to flatten the input directory tree structure so that the output files
        are in the same directory.

    Returns
    -------
    list of str
    """

    if not isinstance(path2, str):
        # caller supplied explicit targets; just check the count
        assert len(paths) == len(path2)
        return path2

    root = path2.rstrip("/")
    if flatten:
        # everything goes directly under root, keeping only basenames
        return ["/".join((root, p.split("/")[-1])) for p in paths]

    cp = common_prefix(paths)
    if exists:
        # destination dir exists: keep the top-level source dir name
        cp = cp.rsplit("/", 1)[0]
    if not cp and all(not s.startswith("/") for s in paths):
        return ["/".join([root, p]) for p in paths]
    return [p.replace(cp, root, 1) for p in paths]
423
+
424
+
425
def is_exception(obj: Any) -> bool:
    """Return True if *obj* is an exception instance (any BaseException)."""
    return isinstance(obj, BaseException)
427
+
428
+
429
def isfilelike(f: Any) -> TypeGuard[IO[bytes]]:
    """Return True when *f* exposes the minimal file protocol used here:
    ``read``, ``close`` and ``tell``."""
    required = ("read", "close", "tell")
    return all(hasattr(f, attr) for attr in required)
431
+
432
+
433
def get_protocol(url: str) -> str:
    """Extract the protocol component of *url* (before ``://`` or ``::``),
    defaulting to ``"file"`` when no protocol marker is present."""
    url = stringify_path(url)
    pieces = re.split(r"(\:\:|\://)", url, maxsplit=1)
    return pieces[0] if len(pieces) > 1 else "file"
439
+
440
+
441
def can_be_local(path: str) -> bool:
    """Can the given URL be used with open_local?"""
    from fsspec import get_filesystem_class

    try:
        fs_cls = get_filesystem_class(get_protocol(path))
    except (ValueError, ImportError):
        # not in registry or import failed
        return False
    return getattr(fs_cls, "local_file", False)
450
+
451
+
452
def get_package_version_without_import(name: str) -> str | None:
    """For given package name, try to find the version without importing it

    Import and package.__version__ is still the backup here, so an import
    *might* happen.

    Returns either the version string, or None if the package
    or the version was not readily found.
    """
    # 1) already imported: read the attribute directly
    mod = sys.modules.get(name)
    if mod is not None and hasattr(mod, "__version__"):
        return mod.__version__
    # 2) installed metadata, no import needed
    try:
        return version(name)
    except:  # noqa: E722
        pass
    # 3) last resort: actually import the module
    try:
        import importlib

        module = importlib.import_module(name)
        return module.__version__
    except (ImportError, AttributeError):
        return None
476
+
477
+
478
def setup_logging(
    logger: logging.Logger | None = None,
    logger_name: str | None = None,
    level: str = "DEBUG",
    clear: bool = True,
) -> logging.Logger:
    """Attach a stream handler with a standard format to the given logger
    (or the logger named *logger_name*) and set its level.

    Raises ValueError when neither a logger nor a logger name is given.
    By default any pre-existing handlers are removed first (``clear``).
    """
    if logger is None and logger_name is None:
        raise ValueError("Provide either logger object or logger name")
    target = logger if logger is not None else logging.getLogger(logger_name)
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s -- %(message)s"
        )
    )
    if clear:
        target.handlers.clear()
    target.addHandler(handler)
    target.setLevel(level)
    return target
497
+
498
+
499
+ def _unstrip_protocol(name: str, fs: AbstractFileSystem) -> str:
500
+ return fs.unstrip_protocol(name)
501
+
502
+
503
def mirror_from(
    origin_name: str, methods: Iterable[str]
) -> Callable[[type[T]], type[T]]:
    """Mirror attributes and methods from the given
    origin_name attribute of the instance to the
    decorated class"""

    def _delegate(attr: str, self: Any) -> Any:
        # look the attribute up on the origin object at access time
        origin = getattr(self, origin_name)
        return getattr(origin, attr)

    def decorator(cls: type[T]) -> type[T]:
        for attr in methods:
            setattr(cls, attr, property(partial(_delegate, attr)))
        return cls

    return decorator
521
+
522
+
523
@contextlib.contextmanager
def nullcontext(obj: T) -> Iterator[T]:
    """Context manager that does nothing except yield *obj* unchanged."""
    yield obj
526
+
527
+
528
def merge_offset_ranges(
    paths: list[str],
    starts: list[int] | int,
    ends: list[int] | int,
    max_gap: int = 0,
    max_block: int | None = None,
    sort: bool = True,
) -> tuple[list[str], list[int], list[int]]:
    """Merge adjacent byte-offset ranges when the inter-range
    gap is <= `max_gap`, and when the merged byte range does not
    exceed `max_block` (if specified). By default, this function
    will re-order the input paths and byte ranges to ensure sorted
    order. If the user can guarantee that the inputs are already
    sorted, passing `sort=False` will skip the re-ordering.
    """
    # Check input
    if not isinstance(paths, list):
        raise TypeError
    # Scalar starts/ends are broadcast to one entry per path
    if not isinstance(starts, list):
        starts = [starts] * len(paths)
    if not isinstance(ends, list):
        ends = [ends] * len(paths)
    if len(starts) != len(paths) or len(ends) != len(paths):
        raise ValueError

    # Early Return
    if len(starts) <= 1:
        return paths, starts, ends

    # None/falsy starts are treated as byte 0
    starts = [s or 0 for s in starts]
    # Sort by paths and then ranges if `sort=True`
    if sort:
        paths, starts, ends = (
            list(v)
            for v in zip(
                *sorted(
                    zip(paths, starts, ends),
                )
            )
        )

    if paths:
        # Loop through the coupled `paths`, `starts`, and
        # `ends`, and merge adjacent blocks when appropriate
        new_paths = paths[:1]
        new_starts = starts[:1]
        new_ends = ends[:1]
        for i in range(1, len(paths)):
            # A merged end of None means "read to end of file", which already
            # covers every later range within the same path — skip those.
            if paths[i] == paths[i - 1] and new_ends[-1] is None:
                continue
            elif (
                paths[i] != paths[i - 1]
                or ((starts[i] - new_ends[-1]) > max_gap)
                or (max_block is not None and (ends[i] - new_starts[-1]) > max_block)
            ):
                # Cannot merge with previous block.
                # Add new `paths`, `starts`, and `ends` elements
                new_paths.append(paths[i])
                new_starts.append(starts[i])
                new_ends.append(ends[i])
            else:
                # Merge with previous block by updating the
                # last element of `ends`
                new_ends[-1] = ends[i]
        return new_paths, new_starts, new_ends

    # `paths` is empty. Just return input lists
    return paths, starts, ends
596
+
597
+
598
def file_size(filelike: IO[bytes]) -> int:
    """Return the total byte length of an open, seekable file-like.

    The current read position is saved, the file is seeked to its end
    (whose offset is the size), and the position is restored even if
    the seek raises.
    """
    saved_position = filelike.tell()
    try:
        end_offset = filelike.seek(0, 2)  # 2 == os.SEEK_END
    finally:
        filelike.seek(saved_position)
    return end_offset
605
+
606
+
607
@contextlib.contextmanager
def atomic_write(path: str, mode: str = "wb"):
    """
    Context manager that updates `path` atomically.

    A temporary file is created in the same directory as `path` (so
    the final rename cannot cross filesystems) and yielded for
    writing. If the `with` body completes normally, the temporary
    file replaces `path` in a single `os.replace` step; if it raises,
    the temporary file is deleted and the exception propagates, so
    `path` is never left half-written.
    """
    tmp_fd, tmp_name = tempfile.mkstemp(
        dir=os.path.dirname(path), prefix=os.path.basename(path) + "-"
    )
    try:
        with open(tmp_fd, mode) as stream:
            yield stream
    except BaseException:
        # Remove the partial temp file; tolerate it already being gone.
        with contextlib.suppress(FileNotFoundError):
            os.unlink(tmp_name)
        raise
    else:
        # Success: atomically promote the temp file to `path`.
        os.replace(tmp_name, path)
626
+
627
+
628
def _translate(pat, STAR, QUESTION_MARK):
    # Copied from: https://github.com/python/cpython/pull/106703.
    """Translate one fnmatch-style pattern segment into regex pieces.

    Returns a list of regex fragments (not a joined string). `STAR` and
    `QUESTION_MARK` are the caller-supplied replacements for `*` and `?`
    so the caller can restrict them (e.g. to non-separator characters).
    NOTE(review): kept byte-identical to the upstream CPython code so
    future diffs against it stay trivial.
    """
    res: list[str] = []
    add = res.append
    i, n = 0, len(pat)
    while i < n:
        c = pat[i]
        i = i + 1
        if c == "*":
            # compress consecutive `*` into one
            if (not res) or res[-1] is not STAR:
                add(STAR)
        elif c == "?":
            add(QUESTION_MARK)
        elif c == "[":
            # Scan ahead for the matching "]"; "!" negates, and a "]"
            # immediately after the opener is a literal member.
            j = i
            if j < n and pat[j] == "!":
                j = j + 1
            if j < n and pat[j] == "]":
                j = j + 1
            while j < n and pat[j] != "]":
                j = j + 1
            if j >= n:
                # Unterminated "[": treat it as a literal bracket.
                add("\\[")
            else:
                stuff = pat[i:j]
                if "-" not in stuff:
                    stuff = stuff.replace("\\", r"\\")
                else:
                    # Split the class body on "-" so range hyphens can be
                    # told apart from literal ones.
                    chunks = []
                    k = i + 2 if pat[i] == "!" else i + 1
                    while True:
                        k = pat.find("-", k, j)
                        if k < 0:
                            break
                        chunks.append(pat[i:k])
                        i = k + 1
                        k = k + 3
                    chunk = pat[i:j]
                    if chunk:
                        chunks.append(chunk)
                    else:
                        # Trailing "-" is a literal hyphen member.
                        chunks[-1] += "-"
                    # Remove empty ranges -- invalid in RE.
                    for k in range(len(chunks) - 1, 0, -1):
                        if chunks[k - 1][-1] > chunks[k][0]:
                            chunks[k - 1] = chunks[k - 1][:-1] + chunks[k][1:]
                            del chunks[k]
                    # Escape backslashes and hyphens for set difference (--).
                    # Hyphens that create ranges shouldn't be escaped.
                    stuff = "-".join(
                        s.replace("\\", r"\\").replace("-", r"\-") for s in chunks
                    )
                # Escape set operations (&&, ~~ and ||).
                stuff = re.sub(r"([&~|])", r"\\\1", stuff)
                i = j + 1
                if not stuff:
                    # Empty range: never match.
                    add("(?!)")
                elif stuff == "!":
                    # Negated empty range: match any character.
                    add(".")
                else:
                    if stuff[0] == "!":
                        stuff = "^" + stuff[1:]
                    elif stuff[0] in ("^", "["):
                        stuff = "\\" + stuff
                    add(f"[{stuff}]")
        else:
            # Ordinary character: escape it for regex safety.
            add(re.escape(c))
    assert i == n
    return res
700
+
701
+
702
def glob_translate(pat):
    """Translate a pathname with shell wildcards to a regular expression.

    Copied from: https://github.com/python/cpython/pull/106703, with the
    keyword parameters fixed to recursive=True, include_hidden=True,
    seps=None. A lone "*" matches one path segment, "**" matches any
    number of segments, and "**" embedded inside a segment is rejected.
    """
    # Both the primary and alternate separators (if any) delimit segments.
    seps = os.path.sep + (os.path.altsep or "")
    escaped = "".join(map(re.escape, seps))
    sep_re = f"[{escaped}]" if len(seps) > 1 else escaped
    non_sep = f"[^{escaped}]"
    final_segment = f"{non_sep}+"
    segment_with_sep = f"{final_segment}{sep_re}"
    deep_prefix = f"(?:.+{sep_re})?"
    pieces = []
    components = re.split(sep_re, pat)
    last_idx = len(components) - 1
    for idx, component in enumerate(components):
        is_last = idx == last_idx
        if component == "*":
            # "*" consumes exactly one segment (separator built in).
            pieces.append(final_segment if is_last else segment_with_sep)
            continue
        if component == "**":
            # "**" consumes zero or more whole segments.
            pieces.append(".*" if is_last else deep_prefix)
            continue
        if "**" in component:
            raise ValueError(
                "Invalid pattern: '**' can only be an entire path component"
            )
        if component:
            pieces.extend(_translate(component, f"{non_sep}*", non_sep))
        if not is_last:
            pieces.append(sep_re)
    body = "".join(pieces)
    return rf"(?s:{body})\Z"