# pip install -e git+https://github.com/CompVis/stable-diffusion.git@master
# pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
# convert_original_stable_diffusion_to_diffusers.py

# TODO
# grayscale
# log audio
# convert to huggingface / train huggingface
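
# Trains an ldm KL autoencoder (VAE) on mel-spectrogram images pulled from a
# Hugging Face dataset, logging reconstructed images and audio to TensorBoard.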

import os
import argparse

import torch
import torchvision
import numpy as np
from PIL import Image
import pytorch_lightning as pl
from omegaconf import OmegaConf
from datasets import load_dataset
from librosa.util import normalize
from ldm.util import instantiate_from_config
from pytorch_lightning.trainer import Trainer
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning.callbacks import Callback, ModelCheckpoint

from audiodiffusion.mel import Mel

class AudioDiffusion(Dataset):

    def __init__(self, model_id):
        super().__init__()
        self.hf_dataset = load_dataset(model_id)['train']

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        # decode the PIL image to a raw uint8 HWC array, then scale to [-1, 1]
        image = self.hf_dataset[idx]['image'].convert('RGB')
        image = np.frombuffer(image.tobytes(), dtype="uint8").reshape(
            (image.height, image.width, 3))
        image = ((image / 255) * 2 - 1)
        return {'image': image}
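
# Quick sanity check (hypothetical snippet; assumes the 256x256 spectrogram
# images of the default dataset 'teticio/audio-diffusion-256'):
#   ds = AudioDiffusion('teticio/audio-diffusion-256')
#   sample = ds[0]['image']  # HWC float array
#   assert sample.shape == (256, 256, 3)
#   assert -1.0 <= sample.min() and sample.max() <= 1.0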

class AudioDiffusionDataModule(pl.LightningDataModule):

    def __init__(self, model_id, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.dataset = AudioDiffusion(model_id)
        self.num_workers = 1

    def train_dataloader(self):
        return DataLoader(self.dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers)
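
# Note: train_dataloader above keeps the DataLoader default shuffle=False;
# for typical VAE training you may want to pass shuffle=True.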

# from https://github.com/CompVis/stable-diffusion/blob/main/main.py
class ImageLogger(Callback):

    def __init__(self,
                 batch_frequency,
                 max_images,
                 clamp=True,
                 increase_log_steps=True,
                 rescale=True,
                 disabled=False,
                 log_on_batch_idx=False,
                 log_first_step=False,
                 log_images_kwargs=None,
                 resolution=256,
                 hop_length=512):
        super().__init__()
        self.mel = Mel(x_res=resolution,
                       y_res=resolution,
                       hop_length=hop_length)
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {
            pl.loggers.TensorBoardLogger: self._testtube,
        }
        self.log_steps = [
            2**n for n in range(int(np.log2(self.batch_freq)) + 1)
        ]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step
    #@rank_zero_only
    def _testtube(self, pl_module, images, batch_idx, split):
        for k in images:
            images_ = (images[k] + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = torchvision.utils.make_grid(images_)

            tag = f"{split}/{k}"
            pl_module.logger.experiment.add_image(
                tag, grid, global_step=pl_module.global_step)

            # convert the whole batch to uint8 HWC once, then reconstruct and
            # log the audio for each image individually
            images_ = (images_.numpy() *
                       255).round().astype("uint8").transpose(0, 2, 3, 1)
            for i, image in enumerate(images_):
                audio = self.mel.image_to_audio(
                    Image.fromarray(image, mode='RGB').convert('L'))
                pl_module.logger.experiment.add_audio(
                    tag + f"/{i}",
                    normalize(audio),
                    global_step=pl_module.global_step,
                    sample_rate=self.mel.get_sample_rate())
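
    # Mel.image_to_audio approximately inverts the spectrogram image back to a
    # waveform (phase is estimated, e.g. via Griffin-Lim), so the logged audio
    # is a reconstruction, not the original clip.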

    #@rank_zero_only
    def log_local(self, save_dir, split, images, global_step, current_epoch,
                  batch_idx):
        root = os.path.join(save_dir, "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                k, global_step, current_epoch, batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)
    def log_img(self, pl_module, batch, batch_idx, split="train"):
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        if (self.check_frequency(check_idx)
                and  # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and
                callable(pl_module.log_images) and self.max_images > 0):
            logger = type(pl_module.logger)

            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                images = pl_module.log_images(batch,
                                              split=split,
                                              **self.log_images_kwargs)

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1., 1.)

            #self.log_local(pl_module.logger.save_dir, split, images,
            #               pl_module.global_step, pl_module.current_epoch,
            #               batch_idx)

            logger_log_images = self.logger_log_images.get(
                logger, lambda *args, **kwargs: None)
            logger_log_images(pl_module, images, pl_module.global_step, split)

            if is_train:
                pl_module.train()
    def check_frequency(self, check_idx):
        if ((check_idx % self.batch_freq) == 0 or
            (check_idx in self.log_steps)) and (check_idx > 0
                                                or self.log_first_step):
            try:
                self.log_steps.pop(0)
            except IndexError:
                pass
            return True
        return False
    def on_train_batch_end(self, trainer, pl_module, outputs, batch,
                           batch_idx):
        if not self.disabled and (pl_module.global_step > 0
                                  or self.log_first_step):
            self.log_img(pl_module, batch, batch_idx, split="train")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train VAE using ldm.")
    parser.add_argument("--batch_size", type=int, default=1)
    args = parser.parse_args()

    config = OmegaConf.load('ldm_autoencoder_kl.yaml')
    lightning_config = config.pop("lightning", OmegaConf.create())
    trainer_config = lightning_config.get("trainer", OmegaConf.create())
    trainer_opt = argparse.Namespace(**trainer_config)
    trainer = Trainer.from_argparse_args(
        trainer_opt,
        callbacks=[
            ImageLogger(batch_frequency=1000,
                        max_images=8,
                        increase_log_steps=False,
                        log_on_batch_idx=True),
            ModelCheckpoint(dirpath='checkpoints',
                            filename='{epoch:06}',
                            verbose=True,
                            save_last=True)
        ])
    model = instantiate_from_config(config.model)
    model.learning_rate = config.model.base_learning_rate
    data = AudioDiffusionDataModule('teticio/audio-diffusion-256',
                                    batch_size=args.batch_size)
    trainer.fit(model, data)
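
# Example invocation (assuming this file is saved as train_vae.py and that
# ldm_autoencoder_kl.yaml, defining model.base_learning_rate and an optional
# lightning.trainer section, sits in the working directory):
#
#   python train_vae.py --batch_size 2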