| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler |
| |
|
| | from models.svc.base import SVCInference |
| | from models.svc.diffusion.diffusion_inference_pipeline import DiffusionInferencePipeline |
| | from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper |
| | from modules.encoder.condition_encoder import ConditionEncoder |
| |
|
| |
|
class DiffusionInference(SVCInference):
    """SVC inference wrapper for diffusion-based acoustic models.

    Builds a (condition encoder, diffusion wrapper) pair and runs the
    reverse-diffusion process through a configurable diffusers scheduler
    (DDPM / DDIM / PNDM) to predict mel-spectrograms from noise.
    """

    # Maps the lower-cased scheduler name from the config to its class.
    # A dispatch table keeps scheduler registration in one place instead
    # of an if/elif chain.
    _SCHEDULER_CLASSES = {
        "ddpm": DDPMScheduler,
        "ddim": DDIMScheduler,
        "pndm": PNDMScheduler,
    }

    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        """
        Args:
            args: Runtime arguments; must expose ``diffusion_inference_steps``.
            cfg: Full config; reads ``cfg.model.diffusion.scheduler_settings``
                and ``cfg.inference.diffusion`` (scheduler name + settings).
            infer_type: Inference data-source mode, forwarded to SVCInference.

        Raises:
            NotImplementedError: If the configured scheduler name is not one
                of "ddpm", "ddim", "pndm" (case-insensitive).
        """
        SVCInference.__init__(self, args, cfg, infer_type)

        # Inference-time scheduler settings override training-time ones.
        settings = {
            **cfg.model.diffusion.scheduler_settings,
            **cfg.inference.diffusion.scheduler_settings,
        }
        # Not a diffusers scheduler constructor kwarg; drop it. Use a
        # default so a config that omits the key does not raise KeyError.
        settings.pop("num_inference_timesteps", None)

        # Hoist the name once instead of re-lowering it per branch.
        scheduler_name = cfg.inference.diffusion.scheduler.lower()
        scheduler_cls = self._SCHEDULER_CLASSES.get(scheduler_name)
        if scheduler_cls is None:
            raise NotImplementedError(
                "Unsupported scheduler type: {}".format(scheduler_name)
            )
        self.scheduler = scheduler_cls(**settings)
        self.logger.info("Using {} scheduler.".format(scheduler_name.upper()))

        # self.model[1] is the acoustic mapper built in _build_model().
        self.pipeline = DiffusionInferencePipeline(
            self.model[1],
            self.scheduler,
            args.diffusion_inference_steps,
        )

    def _build_model(self):
        """Build and return ModuleList([condition_encoder, acoustic_mapper]).

        Propagates the preprocessing F0 range into the condition-encoder
        config before construction, and keeps references to both submodules
        on ``self`` for later indexed access via ``self.model``.
        """
        self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
        self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
        self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
        self.acoustic_mapper = DiffusionWrapper(self.cfg)
        model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
        return model

    def _inference_each_batch(self, batch_data):
        """Run reverse diffusion on one batch.

        Args:
            batch_data: Dict of tensors; must contain a "mel" entry whose
                shape is used as the noise template. All values are moved
                (in place in the dict) to the accelerator device.

        Returns:
            The pipeline's predicted output (mel-spectrogram tensor).
        """
        device = self.accelerator.device
        for k, v in batch_data.items():
            batch_data[k] = v.to(device)

        conditioner = self.model[0](batch_data)
        # Start the reverse process from pure Gaussian noise shaped like
        # the target mel.
        noise = torch.randn_like(batch_data["mel"], device=device)
        y_pred = self.pipeline(noise, conditioner)
        return y_pred
| |
|