|
|
import numpy as np |
|
|
import torch |
|
|
from torchvision import transforms |
|
|
from tqdm import tqdm |
|
|
from PIL import Image |
|
|
import soundfile as sf |
|
|
from mel_module import Mel |
|
|
|
|
|
class Generator:
    """Sample a spectrogram with a conditioned diffusion UNet and save it as audio.

    Sampling uses classifier-free guidance: each step runs the UNet on an
    unconditional and a conditional copy of the latent and blends the two
    noise predictions with ``config.guidance_scale``.
    """

    def __init__(self, config, unet, scheduler, vae, embedding, progress_callback=None):
        """Store the sampling components.

        Args:
            config: Object exposing ``image_size``, ``device`` and
                ``guidance_scale`` (the attributes this class reads).
            unet: Callable ``unet(x, t, encoder_hidden_states=...)`` whose
                result has a ``.sample`` tensor.
            scheduler: Object with ``timesteps``,
                ``scale_model_input(x, timestep=t)`` and
                ``step(model_output, t, sample)`` returning ``.prev_sample``.
            vae: Object whose ``encode(image)`` returns ``(mu, log_var)``.
            embedding: Conditioning tensor, concatenated along dim 0 after the
                unconditional embedding.
            progress_callback: Optional object with a ``tqdm(iterable)`` method
                used to report progress. When ``None``, a plain ``tqdm`` bar
                is used instead.
        """
        self.config = config
        self.unet = unet
        self.scheduler = scheduler
        self.vae = vae
        self.embedding = embedding
        self.progress_callback = progress_callback

    @staticmethod
    def _to_spectrogram(tensor):
        """Map a model-output tensor to a squeezed NumPy array scaled by 255.

        This inverts ``Normalize(mean=0.5, std=0.5)`` via ``x -> (x + 1) / 2``,
        so inputs in [-1, 1] land in [0, 255]. Values outside that range are
        not clipped.
        """
        unit = (tensor.detach().cpu() + 1.0) / 2.0
        return unit.squeeze().numpy() * 255

    def tensor_to_mel(self, tensor):
        """Convert a generated image tensor into a ``Mel`` object.

        Assumes the tensor values are roughly in [-1, 1] (the inverse of a
        0.5/0.5 normalization). TODO confirm against the training pipeline.
        """
        return Mel(spectrogram=self._to_spectrogram(tensor))

    def _progress(self, iterable):
        """Wrap *iterable* in the caller's progress reporter, or in tqdm."""
        # Fall back to tqdm so the default progress_callback=None works.
        if self.progress_callback is None:
            return tqdm(iterable)
        return self.progress_callback.tqdm(iterable)

    def _unconditional_embedding(self):
        """Encode an all-zero image with the VAE to build the "empty" conditioning."""
        size = self.config.image_size
        blank = torch.zeros((1, 1, size, size), device=self.config.device)
        mu, log_var = self.vae.encode(blank)
        # Join mean and log-variance on dim 1, then add a leading axis.
        # Presumably this matches self.embedding's layout for the dim-0 concat
        # in generate(); verify against the embedding producer.
        return torch.cat((mu, log_var), dim=1).unsqueeze(0)

    def _denoise(self, embeddings, seed=None):
        """Run the scheduler's reverse-diffusion loop and return the final latent.

        Args:
            embeddings: Unconditional and conditional embeddings stacked along
                dim 0, passed to the UNet as ``encoder_hidden_states``.
            seed: Optional seed for the initial noise. When ``None``, the fresh
                ``torch.Generator`` keeps PyTorch's default seed, so every call
                starts from the same noise.

        Returns:
            The ``prev_sample`` produced by the last scheduler step.
        """
        size = self.config.image_size
        device = self.config.device
        generator = torch.Generator(device=device)
        if seed is not None:
            generator.manual_seed(seed)
        latents = torch.randn((1, 1, size, size), generator=generator, device=device)

        for t in self._progress(self.scheduler.timesteps):
            # Duplicate the latent so one UNet call yields both the
            # unconditional and the conditional noise prediction.
            model_input = self.scheduler.scale_model_input(torch.cat([latents] * 2), timestep=t)
            noise_pred = self.unet(model_input, t, encoder_hidden_states=embeddings).sample
            noise_uncond, noise_cond = noise_pred.chunk(2)
            # Classifier-free guidance: move away from the unconditional
            # prediction by guidance_scale.
            guided = noise_uncond + self.config.guidance_scale * (noise_cond - noise_uncond)
            latents = self.scheduler.step(guided, t, latents).prev_sample
        return latents

    def generate(self, seed=None):
        """Sample one spectrogram, save it as audio, and return the ``Mel``.

        Args:
            seed: Optional seed for the initial noise. With the default
                ``None``, PyTorch's default generator seed is used.

        Returns:
            The ``Mel`` object whose ``save_audio()`` was called.
        """
        with torch.no_grad():
            embeddings = torch.cat([self._unconditional_embedding(), self.embedding])
            latents = self._denoise(embeddings, seed=seed)
            mel = self.tensor_to_mel(latents.squeeze(1))
        mel.save_audio()
        return mel
|
|
|