Spaces:
Running
Running
| import torch | |
| from diffusers.loaders import AttnProcsLayers | |
| from modules.BEATs.BEATs import BEATs, BEATsConfig | |
| from modules.AudioToken.embedder import FGAEmbedder | |
| from diffusers import AutoencoderKL, UNet2DConditionModel | |
| from diffusers.models.attention_processor import LoRAAttnProcessor | |
class AudioTokenWrapper(torch.nn.Module):
    """Simple wrapper module for Stable Diffusion that holds all the models together.

    Components (all created in ``__init__``):

    * ``text_encoder``, ``unet``, ``vae`` -- Stable Diffusion parts loaded from
      ``args.pretrained_model_name_or_path`` at ``args.revision``.
    * ``aud_encoder`` -- a BEATs model restored from ``BEATS_CHECKPOINT_PATH``
      with its ``predictor`` submodule removed.
    * ``embedder`` -- an ``FGAEmbedder`` with ``768 * 3`` inputs and an output
      width equal to the text encoder's hidden size.
    * ``lora_layers`` -- only when LoRA is enabled: the UNet's LoRA attention
      processors wrapped in an ``AttnProcsLayers`` module.
    """

    # Local path of the BEATs checkpoint; the file must hold 'cfg' and 'model' entries.
    BEATS_CHECKPOINT_PATH = 'models/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt'

    def __init__(
            self,
            args,
            accelerator,
    ):
        """Load every model and put it in the mode required by ``args.data_set``.

        Args:
            args: Run configuration (e.g. an ``argparse.Namespace``). Reads
                ``pretrained_model_name_or_path``, ``revision``, ``data_set``
                ('train' or 'test'), the optional flag ``lora``, and in test
                mode the checkpoint paths ``learned_embeds`` and (with LoRA)
                ``lora_learned_embeds``.
            accelerator: Only ``accelerator.device`` is used, as the
                ``map_location`` for the test-mode checkpoints.
        """
        super().__init__()
        use_lora = self._lora_enabled(args)
        is_test = args.data_set == 'test'

        # Choose the text-encoder implementation up front so it is loaded only
        # once: test mode uses the stock transformers model, every other mode
        # uses the project's own copy in modules.clip_text_model.
        if is_test:
            from transformers import CLIPTextModel
        else:
            from modules.clip_text_model.modeling_clip import CLIPTextModel
        self.text_encoder = CLIPTextModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
        )
        self.unet = UNet2DConditionModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
        )
        self.vae = AutoencoderKL.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
        )
        self.aud_encoder = self._load_audio_encoder()

        # The embedder's output width must match the text encoder's hidden size
        # (768 for SD v1.x, 1024 for SD v2.x). Read it from the loaded config
        # rather than special-casing a single model id, which mis-sized every
        # other v1.x checkpoint.
        input_size = 768 * 3
        self.embedder = FGAEmbedder(
            input_size=input_size, output_size=self.text_encoder.config.hidden_size
        )

        for pretrained in (self.vae, self.unet, self.text_encoder, self.aud_encoder):
            pretrained.eval()

        if use_lora:
            self._attach_lora_processors()

        if args.data_set == 'train':
            # Freeze vae, unet, text_enc and aud_encoder; the embedder learns.
            for frozen in (self.vae, self.unet, self.text_encoder, self.aud_encoder):
                frozen.requires_grad_(False)
            self.embedder.requires_grad_(True)
            self.embedder.train()
            if use_lora:
                # The LoRA processors live inside the UNet, so the
                # unet.requires_grad_(False) above froze them as well;
                # re-enable them so the LoRA weights actually train.
                self.lora_layers.requires_grad_(True)
                self.unet.train()
        elif is_test:
            self.embedder.eval()
            self.embedder.load_state_dict(
                torch.load(args.learned_embeds, map_location=accelerator.device)
            )
            if use_lora:
                self.lora_layers.eval()
                lora_weights_path = args.lora_learned_embeds
                self.lora_layers.load_state_dict(
                    torch.load(lora_weights_path, map_location=accelerator.device)
                )
                # NOTE(review): the same LoRA file is loaded a second time here.
                # Depending on the diffusers version, load_attn_procs may install
                # new processor modules on the UNet, leaving self.lora_layers
                # wrapping the previous ones -- confirm which path is intended.
                self.unet.load_attn_procs(lora_weights_path)

    @staticmethod
    def _lora_enabled(args):
        """Return True when ``args`` carries a truthy ``lora`` attribute.

        ``getattr`` (instead of ``'lora' in args``) also works for config
        objects that do not implement ``__contains__``, such as
        ``types.SimpleNamespace``.
        """
        return bool(getattr(args, 'lora', False))

    def _load_audio_encoder(self):
        """Build the BEATs audio encoder from ``BEATS_CHECKPOINT_PATH``.

        The checkpoint is mapped to CPU so it loads even on machines without
        the device it was saved from; the BEATs module itself is constructed
        on the default device and receives the weights via ``load_state_dict``.
        """
        # torch.load unpickles the file: only point this at trusted checkpoints.
        checkpoint = torch.load(self.BEATS_CHECKPOINT_PATH, map_location="cpu")
        encoder = BEATs(BEATsConfig(checkpoint['cfg']))
        encoder.load_state_dict(checkpoint['model'])
        # Drop the `predictor` submodule (presumably the AudioSet classification
        # head of this fine-tuned checkpoint -- TODO confirm in BEATs).
        encoder.predictor = None
        return encoder

    def _attach_lora_processors(self):
        """Install a ``LoRAAttnProcessor`` on every UNet attention layer.

        Sets ``self.lora_layers`` to an ``AttnProcsLayers`` wrapping the
        processors now installed on ``self.unet``.
        """
        cross_attention_dim = self.unet.config.cross_attention_dim
        block_out_channels = self.unet.config.block_out_channels
        lora_attn_procs = {
            name: LoRAAttnProcessor(
                hidden_size=self._lora_hidden_size(name, block_out_channels),
                # attn1 processors (self-attention) take no cross-attention context.
                cross_attention_dim=None if name.endswith("attn1.processor") else cross_attention_dim,
            )
            for name in self.unet.attn_processors
        }
        self.unet.set_attn_processor(lora_attn_procs)
        self.lora_layers = AttnProcsLayers(self.unet.attn_processors)

    @staticmethod
    def _lora_hidden_size(name, block_out_channels):
        """Return the channel width of the UNet block that owns processor ``name``.

        ``name`` is a dotted processor path such as
        ``"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor"``;
        the block index is the second dotted component (any number of digits).
        Up blocks run in reverse channel order.

        Raises:
            ValueError: if ``name`` is not under ``mid_block``, ``up_blocks``
                or ``down_blocks``.
        """
        if name.startswith("mid_block"):
            return block_out_channels[-1]
        if name.startswith("up_blocks"):
            block_id = int(name.split(".")[1])
            return list(reversed(block_out_channels))[block_id]
        if name.startswith("down_blocks"):
            block_id = int(name.split(".")[1])
            return block_out_channels[block_id]
        raise ValueError(f"Unexpected attention processor name: {name!r}")