camenduru committed on
Commit
b68d3d7
·
1 Parent(s): f19d49c

Delete utilities.py

Browse files
Files changed (1) hide show
  1. utilities.py +0 -537
utilities.py DELETED
@@ -1,537 +0,0 @@
1
- #
2
- # Copyright 2022 The HuggingFace Inc. team.
3
- # SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: Apache-2.0
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
7
- # you may not use this file except in compliance with the License.
8
- # You may obtain a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS,
14
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- # See the License for the specific language governing permissions and
16
- # limitations under the License.
17
- #
18
-
19
- from collections import OrderedDict
20
- from copy import copy
21
- import numpy as np
22
- import os
23
- import math
24
- from PIL import Image
25
- from polygraphy.backend.common import bytes_from_path
26
- from polygraphy.backend.trt import CreateConfig, Profile
27
- from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine
28
- from polygraphy.backend.trt import util as trt_util
29
- from polygraphy import cuda
30
- import random
31
- from scipy import integrate
32
- import tensorrt as trt
33
- import torch
34
-
35
# Module-level TensorRT logger, constructed with the WARNING level.
- TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
36
-
37
class Engine():
    """Build, load and run a TensorRT engine stored at ``<engine_dir>/<model_name>.plan``.

    Typical call order is ``build`` (optional, writes the plan file),
    ``activate`` (loads the plan and creates an execution context),
    ``allocate_buffers`` and then ``infer``.
    """

    def __init__(
        self,
        model_name,
        engine_dir,
    ):
        """Record the plan path and create empty buffer/tensor registries.

        Args:
            model_name: Base name of the plan file (``.plan`` is appended).
            engine_dir: Directory that holds (or will hold) the plan file.
        """
        self.engine_path = os.path.join(engine_dir, model_name+'.plan')
        self.engine = None
        self.context = None
        self.buffers = OrderedDict()
        self.tensors = OrderedDict()

    def __del__(self):
        """Free owned device arrays and drop the engine-related attributes."""
        # __del__ also runs for instances whose __init__ never completed, so
        # every attribute access here has to tolerate a missing attribute.
        buffers = getattr(self, 'buffers', None)
        if buffers:
            for buf in buffers.values():
                # Only DeviceArray objects are freed here; other buffer kinds
                # (e.g. the DeviceView objects made by allocate_buffers) are skipped.
                if isinstance(buf, cuda.DeviceArray):
                    buf.free()
        for attr in ('engine', 'context', 'buffers', 'tensors'):
            try:
                delattr(self, attr)
            except AttributeError:
                pass

    @staticmethod
    def _parse_version(version):
        """Return the first three dotted components of ``version`` as an int tuple.

        Only the leading digits of each component are used, so suffixed
        versions such as ``"10.0.0b6"`` parse as ``(10, 0, 0)``. Missing or
        non-numeric components count as ``0``.
        """
        numbers = []
        for component in version.split(".")[:3]:
            digits = ""
            for char in component:
                if char not in "0123456789":
                    break
                digits += char
            numbers.append(int(digits) if digits else 0)
        while len(numbers) < 3:
            numbers.append(0)
        return tuple(numbers)

    def build(self, onnx_path, fp16, input_profile=None, enable_preview=False):
        """Build an engine from an ONNX file and save it to ``self.engine_path``.

        Args:
            onnx_path: Path of the ONNX model to convert.
            fp16: Passed to ``CreateConfig(fp16=...)``.
            input_profile: Optional mapping ``name -> (min, opt, max)`` shapes
                added to the optimization profile.
            enable_preview: When true and TensorRT is 8.5.1 or newer, enable
                the ``FASTER_DYNAMIC_SHAPES_0805`` preview feature.
        """
        print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
        p = Profile()
        if input_profile:
            for name, dims in input_profile.items():
                assert len(dims) == 3
                p.add(name, min=dims[0], opt=dims[1], max=dims[2])

        preview_features = []
        # FASTER_DYNAMIC_SHAPES_0805 should only be used for TRT 8.5.1 or above.
        if enable_preview and self._parse_version(trt.__version__) >= (8, 5, 1):
            preview_features = [trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805]

        config = CreateConfig(fp16=fp16, profiles=[p], preview_features=preview_features)
        engine = engine_from_network(network_from_onnx_path(onnx_path), config=config)
        save_engine(engine, path=self.engine_path)

    def activate(self):
        """Load the plan file and create an execution context for it."""
        print(f"Loading TensorRT engine: {self.engine_path}")
        self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
        self.context = self.engine.create_execution_context()

    def allocate_buffers(self, shape_dict=None, device='cuda'):
        """Allocate one torch tensor per binding and expose it as a device view.

        Args:
            shape_dict: Optional mapping ``binding name -> shape`` overriding
                the shape reported by the engine.
            device: Torch device the tensors are allocated on.
        """
        for idx in range(trt_util.get_bindings_per_profile(self.engine)):
            binding = self.engine[idx]
            if shape_dict and binding in shape_dict:
                shape = shape_dict[binding]
            else:
                shape = self.engine.get_binding_shape(binding)
            dtype = trt_util.np_dtype_from_trt(self.engine.get_binding_dtype(binding))
            if self.engine.binding_is_input(binding):
                self.context.set_binding_shape(idx, shape)
            # Workaround to convert np dtype to torch
            torch_dtype = torch.from_numpy(np.empty(shape=[], dtype=dtype)).dtype
            # Allocate directly on the target device instead of CPU + copy.
            tensor = torch.empty(tuple(shape), dtype=torch_dtype, device=device)
            self.tensors[binding] = tensor
            self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)

    def infer(self, feed_dict, stream):
        """Run the engine asynchronously on ``stream``.

        Args:
            feed_dict: Mapping ``binding name -> cuda.DeviceView`` replacing
                the allocated buffer of that binding for this call.
            stream: Stream object whose ``ptr`` is passed as the stream handle.

        Returns:
            The ``self.tensors`` mapping created by ``allocate_buffers``.

        Raises:
            ValueError: If the execution call reports failure.
        """
        start_binding, _ = trt_util.get_active_profile_bindings(self.context)
        # shallow copy of ordered dict
        device_buffers = copy(self.buffers)
        for name, buf in feed_dict.items():
            assert isinstance(buf, cuda.DeviceView)
            device_buffers[name] = buf
        # Bindings that precede the active profile are padded with null pointers.
        bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
        noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
        if not noerror:
            raise ValueError("ERROR: inference failed.")

        return self.tensors
111
-
112
class LMSDiscreteScheduler():
    """Linear multistep (LMS) scheduler with precomputed coefficients.

    Call order: ``set_timesteps`` -> ``configure`` -> repeated
    ``scale_model_input`` / ``step``.
    """

    def __init__(
        self,
        device = 'cuda',
        beta_start = 0.00085,
        beta_end = 0.012,
        num_train_timesteps = 1000,
    ):
        """Create the beta schedule and the training-resolution sigma table.

        Args:
            device: Torch device that ``set_timesteps`` moves sigmas/timesteps to.
            beta_start: First beta of the schedule.
            beta_end: Last beta of the schedule.
            num_train_timesteps: Number of entries in the beta schedule.
        """
        self.num_train_timesteps = num_train_timesteps
        # Maximum number of past derivatives combined in one step.
        self.order = 4

        self.beta_start = beta_start
        self.beta_end = beta_end
        # Betas are linearly spaced in sqrt-space and then squared.
        betas = (torch.linspace(beta_start**0.5, beta_end**0.5, self.num_train_timesteps, dtype=torch.float32) ** 2)
        alphas = 1.0 - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)

        sigmas = self._train_sigmas()
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = self.sigmas.max()

        self.device = device

    def _train_sigmas(self):
        """Return ``sqrt((1 - alphas_cumprod) / alphas_cumprod)`` as a numpy array."""
        return np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)

    def set_timesteps(self, steps):
        """Select ``steps`` inference timesteps and interpolate their sigmas.

        Timesteps are evenly spaced over ``[0, num_train_timesteps - 1]`` in
        descending order; a trailing ``0.0`` sigma is appended. Also clears
        the derivative history used by ``step``.
        """
        self.num_inference_steps = steps

        timesteps = np.linspace(0, self.num_train_timesteps - 1, steps, dtype=float)[::-1].copy()
        sigmas = self._train_sigmas()
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=self.device)

        # Move all timesteps to correct device beforehand
        self.timesteps = torch.from_numpy(timesteps).to(device=self.device).float()
        self.derivatives = []

    def scale_model_input(self, sample: torch.FloatTensor, idx, *args, **kwargs) -> torch.FloatTensor:
        """Scale ``sample`` by ``1 / sqrt(sigma[idx]**2 + 1)`` (requires ``configure``)."""
        return sample * self.latent_scales[idx]

    def configure(self):
        """Precompute input scales and LMS coefficients for every inference step.

        Must be called after ``set_timesteps``. Step ``i`` uses
        ``min(i + 1, self.order)`` coefficients, matching the number of
        derivatives ``step`` has accumulated by then.
        """
        max_order = self.order
        self.lms_coeffs = []
        self.latent_scales = [1./((sigma**2 + 1) ** 0.5) for sigma in self.sigmas]

        # Plain Python floats keep the quadrature integrand free of tensor
        # operations (and of device round-trips when sigmas live on a GPU).
        sigmas = self.sigmas.tolist()

        def get_lms_coefficient(order, t, current_order):
            """
            Compute a linear multistep coefficient.
            """
            def lms_derivative(tau):
                # Lagrange basis polynomial through the last `order` sigmas.
                prod = 1.0
                for k in range(order):
                    if current_order == k:
                        continue
                    prod *= (tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])
                return prod
            return integrate.quad(lms_derivative, sigmas[t], sigmas[t + 1], epsrel=1e-4)[0]

        for step_index in range(self.num_inference_steps):
            # Derive the order from the fixed maximum each time; clamping a
            # running variable would pin it at 1 after the first step.
            order = min(step_index + 1, max_order)
            self.lms_coeffs.append([get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)])

    def step(self, output, latents, idx, timestep):
        """Advance ``latents`` by one step using the model ``output``.

        Args:
            output: Model prediction for the current step.
            latents: Current sample.
            idx: Index of the current step into ``sigmas`` / ``lms_coeffs``.
            timestep: Accepted for interface parity; not used here.

        Returns:
            The sample for the next step.
        """
        # compute the previous noisy sample x_t -> x_t-1
        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        sigma = self.sigmas[idx]
        pred_original_sample = latents - sigma * output
        # 2. Convert to an ODE derivative
        derivative = (latents - pred_original_sample) / sigma
        self.derivatives.append(derivative)
        if len(self.derivatives) > self.order:
            self.derivatives.pop(0)
        # 3. Compute previous sample based on the derivatives path
        prev_sample = latents + sum(
            coeff * derivative for coeff, derivative in zip(self.lms_coeffs[idx], reversed(self.derivatives))
        )

        return prev_sample
193
-
194
class DPMScheduler():
    """Multistep DPM-Solver / DPM-Solver++ scheduler with precomputed coefficients.

    Call order: ``set_timesteps`` -> ``configure`` -> repeated ``step``.
    ``configure`` may be called again after a new ``set_timesteps``; it
    rebuilds every coefficient table from scratch.
    """

    def __init__(
        self,
        beta_start = 0.00085,
        beta_end = 0.012,
        num_train_timesteps = 1000,
        solver_order = 2,
        predict_epsilon = True,
        thresholding = False,
        dynamic_thresholding_ratio = 0.995,
        sample_max_value = 1.0,
        algorithm_type = "dpmsolver++",
        solver_type = "midpoint",
        lower_order_final = True,
        device = 'cuda',
    ):
        """Build the noise schedule tables and store the solver settings.

        Raises:
            NotImplementedError: If ``algorithm_type`` is not ``"dpmsolver"`` /
                ``"dpmsolver++"`` or ``solver_type`` is not ``"midpoint"`` / ``"heun"``.
        """
        # settings for DPM-Solver
        if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
            raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
        if solver_type not in ["midpoint", "heun"]:
            raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

        # this schedule is very specific to the latent diffusion model.
        self.betas = (
            torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        )

        self.device = device
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        self.algorithm_type = algorithm_type
        self.predict_epsilon = predict_epsilon
        self.thresholding = thresholding
        self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
        self.sample_max_value = sample_max_value
        self.lower_order_final = lower_order_final

        # setable values
        self.num_inference_steps = None
        self.solver_order = solver_order
        self.num_train_timesteps = num_train_timesteps
        self.solver_type = solver_type

        self._reset_coefficients()

    def _reset_coefficients(self):
        """Empty every per-step coefficient table."""
        self.first_order_first_coef = []
        self.first_order_second_coef = []

        self.second_order_first_coef = []
        self.second_order_second_coef = []
        self.second_order_third_coef = []

        self.third_order_first_coef = []
        self.third_order_second_coef = []
        self.third_order_third_coef = []
        self.third_order_fourth_coef = []

    def _transition_terms(self, timestep, prev_timestep):
        """Return ``(alpha_t, alpha_s, sigma_t, sigma_s, h)`` for a step ``timestep -> prev_timestep``.

        The ``_t`` values are read at ``prev_timestep`` (the target), the
        ``_s`` values at ``timestep`` (the source); ``h`` is the difference of
        the two ``lambda_t`` entries.
        """
        alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
        sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
        h = self.lambda_t[prev_timestep] - self.lambda_t[timestep]
        return alpha_t, alpha_s, sigma_t, sigma_s, h

    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
        """Return ``sample`` unchanged; this scheduler applies no input scaling."""
        return sample

    def configure(self):
        """Precompute first-, second- and third-order coefficients for every step.

        Raises:
            ValueError: If ``set_timesteps`` has not been called yet.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # Start from empty tables so a repeated configure() never leaves
        # coefficients from an earlier schedule at the indices step() reads.
        self._reset_coefficients()

        last_index = len(self.timesteps) - 1
        for step_index in range(self.num_inference_steps):
            timestep = self.timesteps[step_index]
            prev_timestep = 0 if step_index == last_index else self.timesteps[step_index + 1]

            self.dpm_solver_first_order_coefs_precompute(timestep, prev_timestep)

            # Only the last list entry feeds the coefficients; earlier entries
            # are clamped to index 0 so the first steps do not wrap around.
            one_back = self.timesteps[max(step_index - 1, 0)]
            two_back = self.timesteps[max(step_index - 2, 0)]

            self.multistep_dpm_solver_second_order_coefs_precompute([one_back, timestep], prev_timestep)
            self.multistep_dpm_solver_third_order_coefs_precompute([two_back, one_back, timestep], prev_timestep)

    def dpm_solver_first_order_coefs_precompute(self, timestep, prev_timestep):
        """Append the first-order coefficients for ``timestep -> prev_timestep``."""
        alpha_t, alpha_s, sigma_t, sigma_s, h = self._transition_terms(timestep, prev_timestep)
        if self.algorithm_type == "dpmsolver++":
            self.first_order_first_coef.append(sigma_t / sigma_s)
            self.first_order_second_coef.append(alpha_t * (torch.exp(-h) - 1.0))
        elif self.algorithm_type == "dpmsolver":
            self.first_order_first_coef.append(alpha_t / alpha_s)
            self.first_order_second_coef.append(sigma_t * (torch.exp(h) - 1.0))

    def multistep_dpm_solver_second_order_coefs_precompute(self, timestep_list, prev_timestep):
        """Append the second-order coefficients for ``timestep_list[-1] -> prev_timestep``."""
        alpha_t, alpha_s0, sigma_t, sigma_s0, h = self._transition_terms(timestep_list[-1], prev_timestep)
        if self.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2211.01095 for detailed derivations
            self.second_order_first_coef.append(sigma_t / sigma_s0)
            self.second_order_second_coef.append((alpha_t * (torch.exp(-h) - 1.0)))
            if self.solver_type == "midpoint":
                self.second_order_third_coef.append(0.5 * (alpha_t * (torch.exp(-h) - 1.0)))
            elif self.solver_type == "heun":
                self.second_order_third_coef.append(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0))
        elif self.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            self.second_order_first_coef.append(alpha_t / alpha_s0)
            self.second_order_second_coef.append((sigma_t * (torch.exp(h) - 1.0)))
            if self.solver_type == "midpoint":
                self.second_order_third_coef.append(0.5 * (sigma_t * (torch.exp(h) - 1.0)))
            elif self.solver_type == "heun":
                self.second_order_third_coef.append((sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)))

    def multistep_dpm_solver_third_order_coefs_precompute(self, timestep_list, prev_timestep):
        """Append the third-order coefficients for ``timestep_list[-1] -> prev_timestep``."""
        alpha_t, alpha_s0, sigma_t, sigma_s0, h = self._transition_terms(timestep_list[-1], prev_timestep)
        if self.algorithm_type == "dpmsolver++":
            self.third_order_first_coef.append(sigma_t / sigma_s0)
            self.third_order_second_coef.append(alpha_t * (torch.exp(-h) - 1.0))
            self.third_order_third_coef.append(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0))
            self.third_order_fourth_coef.append(alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5))
        elif self.algorithm_type == "dpmsolver":
            self.third_order_first_coef.append(alpha_t / alpha_s0)
            self.third_order_second_coef.append(sigma_t * (torch.exp(h) - 1.0))
            self.third_order_third_coef.append(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0))
            self.third_order_fourth_coef.append(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5))

    def set_timesteps(self, num_inference_steps):
        """Choose the inference timesteps and reset the multistep history.

        Timesteps are ``num_inference_steps`` rounded, descending values from
        an even grid over ``[0, num_train_timesteps - 1]`` (the final ``0`` is
        dropped), stored as an int32 tensor on ``self.device``.
        """
        self.num_inference_steps = num_inference_steps
        timesteps = (
            np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
            .round()[::-1][:-1]
            .copy()
            .astype(np.int32)
        )
        self.timesteps = torch.from_numpy(timesteps).to(self.device)
        self.model_outputs = [
            None,
        ] * self.solver_order
        self.lower_order_nums = 0

    def convert_model_output(
        self, model_output, timestep, sample
    ):
        """Convert the raw model output to the quantity the chosen algorithm integrates.

        ``dpmsolver++`` returns a data (x0) prediction, optionally
        dynamically thresholded; ``dpmsolver`` returns a noise prediction.
        """
        # DPM-Solver++ needs to solve an integral of the data prediction model.
        if self.algorithm_type == "dpmsolver++":
            if self.predict_epsilon:
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            else:
                x0_pred = model_output
            if self.thresholding:
                # Dynamic thresholding in https://arxiv.org/abs/2205.11487
                dynamic_max_val = torch.quantile(
                    torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.dynamic_thresholding_ratio, dim=1
                )
                # Never clamp tighter than sample_max_value; the indexing adds
                # trailing singleton axes so the bound broadcasts per sample.
                dynamic_max_val = torch.maximum(
                    dynamic_max_val,
                    self.sample_max_value * torch.ones_like(dynamic_max_val),
                )[(...,) + (None,) * (x0_pred.ndim - 1)]
                x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
            return x0_pred
        # DPM-Solver needs to solve an integral of the noise prediction model.
        elif self.algorithm_type == "dpmsolver":
            if self.predict_epsilon:
                return model_output
            else:
                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                epsilon = (sample - alpha_t * model_output) / sigma_t
                return epsilon

    def dpm_solver_first_order_update(
        self,
        idx,
        model_output,
        sample
    ):
        """Apply the first-order update using the coefficients stored at ``idx``."""
        first_coef = self.first_order_first_coef[idx]
        second_coef = self.first_order_second_coef[idx]

        # Both algorithm types share this form; the difference between them
        # is already baked into the precomputed coefficients.
        if self.algorithm_type == "dpmsolver++":
            x_t = first_coef * sample - second_coef * model_output
        elif self.algorithm_type == "dpmsolver":
            x_t = first_coef * sample - second_coef * model_output
        return x_t

    def multistep_dpm_solver_second_order_update(
        self,
        idx,
        model_output_list,
        timestep_list,
        prev_timestep,
        sample
    ):
        """Apply the second-order multistep update for step ``idx``.

        Uses the last two entries of ``model_output_list`` / ``timestep_list``.
        """
        t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
        m0, m1 = model_output_list[-1], model_output_list[-2]
        lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
        r0 = h_0 / h
        D0, D1 = m0, (1.0 / r0) * (m0 - m1)

        first_coef = self.second_order_first_coef[idx]
        second_coef = self.second_order_second_coef[idx]
        third_coef = self.second_order_third_coef[idx]

        if self.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2211.01095 for detailed derivations
            if self.solver_type == "midpoint":
                x_t = (
                    first_coef * sample
                    - second_coef * D0
                    - third_coef * D1
                )
            elif self.solver_type == "heun":
                x_t = (
                    first_coef * sample
                    - second_coef * D0
                    + third_coef * D1
                )
        elif self.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            if self.solver_type == "midpoint":
                x_t = (
                    first_coef * sample
                    - second_coef * D0
                    - third_coef * D1
                )
            elif self.solver_type == "heun":
                x_t = (
                    first_coef * sample
                    - second_coef * D0
                    - third_coef * D1
                )
        return x_t

    def multistep_dpm_solver_third_order_update(
        self,
        idx,
        model_output_list,
        timestep_list,
        prev_timestep,
        sample
    ):
        """Apply the third-order multistep update for step ``idx``.

        Uses the last three entries of ``model_output_list`` / ``timestep_list``.
        """
        t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
        lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
            self.lambda_t[t],
            self.lambda_t[s0],
            self.lambda_t[s1],
            self.lambda_t[s2],
        )
        h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
        r0, r1 = h_0 / h, h_1 / h
        D0 = m0
        D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
        D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)

        first_coef = self.third_order_first_coef[idx]
        second_coef = self.third_order_second_coef[idx]
        third_coef = self.third_order_third_coef[idx]
        fourth_coef = self.third_order_fourth_coef[idx]

        if self.algorithm_type == "dpmsolver++":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            x_t = (
                first_coef * sample
                - second_coef * D0
                + third_coef * D1
                - fourth_coef * D2
            )
        elif self.algorithm_type == "dpmsolver":
            # See https://arxiv.org/abs/2206.00927 for detailed derivations
            x_t = (
                first_coef * sample
                - second_coef * D0
                - third_coef * D1
                - fourth_coef * D2
            )
        return x_t

    def step(self, output, latents, step_index, timestep):
        """Advance ``latents`` by one solver step.

        Args:
            output: Raw model output for the current step.
            latents: Current sample.
            step_index: Index of the current step into ``self.timesteps``.
            timestep: Timestep value of the current step.

        Returns:
            The sample for the next step.

        Raises:
            ValueError: If ``set_timesteps`` has not been called yet.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        last_index = len(self.timesteps) - 1
        prev_timestep = 0 if step_index == last_index else self.timesteps[step_index + 1]
        # With fewer than 15 steps and lower_order_final set, the final step is
        # forced to first order and the one before it to second order.
        short_schedule = self.lower_order_final and len(self.timesteps) < 15
        lower_order_final = (step_index == last_index) and short_schedule
        lower_order_second = (step_index == last_index - 1) and short_schedule

        output = self.convert_model_output(output, timestep, latents)
        # Shift the history left by one and store the newest output last.
        for i in range(self.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = output

        if self.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
            prev_sample = self.dpm_solver_first_order_update(step_index, output, latents)
        elif self.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
            timestep_list = [self.timesteps[step_index - 1], timestep]
            prev_sample = self.multistep_dpm_solver_second_order_update(
                step_index, self.model_outputs, timestep_list, prev_timestep, latents
            )
        else:
            timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
            prev_sample = self.multistep_dpm_solver_third_order_update(
                step_index, self.model_outputs, timestep_list, prev_timestep, latents
            )

        if self.lower_order_nums < self.solver_order:
            self.lower_order_nums += 1

        return prev_sample
528
-
529
def save_image(images, image_path_dir, image_name_prefix):
    """
    Write each image of a batch to its own PNG file.

    Values are mapped with ``(x + 1) * 255 / 2``, clamped to ``[0, 255]``,
    rounded and stored as uint8 after moving axis 1 to the last position.
    File names are ``<prefix><1-based index>-<random 4-digit number>.png``
    inside ``image_path_dir``.
    """
    scaled = (images + 1) * 255 / 2
    as_bytes = scaled.clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8)
    frames = as_bytes.cpu().numpy()
    total = frames.shape[0]
    for number, frame in enumerate(frames, start=1):
        # Random suffix keeps repeated runs from reusing the same file name.
        suffix = str(random.randint(1000,9999))
        file_name = image_name_prefix+str(number)+'-'+suffix+'.png'
        image_path = os.path.join(image_path_dir, file_name)
        print(f"Saving image {number} / {total} to: {image_path}")
        Image.fromarray(frame).save(image_path)