from typing import Optional

import torch
from diffusers import DiffusionPipeline
| |
|
| |
|
class DiffusionInferencePipeline(DiffusionPipeline):
    r"""Reverse-diffusion inference pipeline driven by a denoising network.

    Args:
        network: The denoising model; invoked as
            ``network(mel, timestep, conditioner)`` at every step.
        scheduler: A diffusers scheduler providing ``set_timesteps`` and
            ``step``.
        num_inference_timesteps: Number of denoising steps to run during
            inference. More steps usually yield higher quality at the
            expense of slower inference.
    """

    def __init__(self, network, scheduler, num_inference_timesteps=1000):
        super().__init__()

        # register_modules exposes the modules as self.network/self.scheduler
        # and lets diffusers handle them in save/load utilities.
        self.register_modules(network=network, scheduler=scheduler)
        self.num_inference_timesteps = num_inference_timesteps

    @torch.inference_mode()
    def __call__(
        self,
        initial_noise: torch.Tensor,
        conditioner: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""Denoise ``initial_noise`` over the full scheduler timestep sequence.

        Args:
            initial_noise: The initial noise to be denoised. Its first
                dimension is treated as the batch dimension.
            conditioner: Optional conditioning tensor forwarded to the
                network at every denoising step.

        Returns:
            The denoised tensor, clamped to ``[-1.0, 1.0]``.
        """
        mel = initial_noise
        batch_size = mel.size(0)
        self.scheduler.set_timesteps(self.num_inference_timesteps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # The network expects one timestep value per batch element.
            timestep = torch.full((batch_size,), t, device=mel.device, dtype=torch.long)

            model_output = self.network(mel, timestep, conditioner)

            # Advance one reverse-diffusion step; prev_sample is the
            # slightly-less-noisy sample for the next iteration.
            mel = self.scheduler.step(model_output, t, mel).prev_sample

        # NOTE(review): assumes the data domain is normalized to [-1, 1]
        # (typical for mel-spectrogram diffusion) — clamp any outliers.
        mel = mel.clamp(-1.0, 1.0)

        return mel