# NOTE: removed paste artifact ("Spaces:" / "Runtime error" header lines from log extraction)
import json
import os

import einops
import numpy as np
import soundfile as sf
import torch
import torch as th
from diffusers.models import AutoencoderKL
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
from torch.utils.data import DataLoader
from tqdm import tqdm

from converter import denormalize, denormalize_spectrogram
from dataset import AudioVideoDataset, LatentDataset
#################################################################################
#                                  Video Utils                                  #
#################################################################################
def preprocess_video(video):
    """Convert a model-space video tensor (t, c, h, w) into a uint8 numpy
    array laid out as (t, h, w, c), ready for image-sequence writers.

    `out2img` performs the value-range conversion to uint8 (and moves the
    tensor to GPU), so here we only transpose and pull back to host memory.
    """
    frames = out2img(video)
    frames = einops.rearrange(frames, 't c h w -> t h w c')
    return frames.cpu().numpy()
def preprocess_video_batch(videos):
    """Preprocess a batch of videos (B, t, c, h, w) into one stacked uint8
    numpy array of shape (B, t, h, w, c).

    The previous implementation staged the per-video results in an
    object-dtype ndarray (`np.empty(B, dtype=np.ndarray)`) before stacking.
    A plain list comprehension feeding `np.stack` is equivalent, simpler,
    and sidesteps numpy's fragile ragged/object-array handling.
    """
    return np.stack([preprocess_video(v) for v in videos], axis=0)
def save_latents(video, audio, y, output_path, name_prefix, ext=".pt"):
    """Serialize video/audio latents and label `y` with `torch.save`.

    The payload is a dict with keys "video", "audio", "y", written to
    `<output_path>/<name_prefix><ext>`; the directory is created if needed.
    """
    os.makedirs(output_path, exist_ok=True)
    payload = {"video": video, "audio": audio, "y": y}
    target = os.path.join(output_path, name_prefix + ext)
    th.save(payload, target)
def save_multimodal(video, audio, output_path, name_prefix, video_fps=10, audio_fps=16000, audio_dir=None):
    """Write `audio` as a wav and `video` as an mp4 muxed with that audio.

    Layout: `<audio_dir>/audio/<prefix>_audio.wav` (audio_dir defaults to
    output_path) and `<output_path>/video/<prefix>_video.mp4`.
    """
    if not audio_dir:
        audio_dir = output_path
    # prepare output folders
    wav_dir = os.path.join(audio_dir, "audio")
    mp4_dir = os.path.join(output_path, "video")
    os.makedirs(wav_dir, exist_ok=True)
    os.makedirs(mp4_dir, exist_ok=True)
    wav_path = os.path.join(wav_dir, name_prefix + "_audio.wav")
    mp4_path = os.path.join(mp4_dir, name_prefix + "_video.mp4")
    # write the wav first: the video clip streams it back in from disk
    sf.write(wav_path, audio, samplerate=audio_fps)
    frames = list(preprocess_video(video))
    clip = ImageSequenceClip(frames, fps=video_fps)
    clip = clip.with_audio(AudioFileClip(wav_path))
    clip.write_videofile(mp4_path, video_fps, audio=True, audio_fps=audio_fps)
def get_dataloader(args, logger, sequence_length, train, latents=False):
    """Build a shuffling DataLoader over either precomputed latents
    (`latents=True`) or raw audio/video clips.

    Reads `data_path`, `image_size`, `ignore_cache`, `num_classes`,
    `target_video_fps`, `batch_size`, `num_workers` from `args`.
    Logs (or prints, when `logger` is None) the dataset size.
    """
    if latents:
        dataset = LatentDataset(args.data_path, train=train)
    else:
        dataset = AudioVideoDataset(
            args.data_path,
            train=train,
            sample_every_n_frames=1,
            resolution=args.image_size,
            sequence_length=sequence_length,
            audio_channels=1,
            sample_rate=16000,
            min_length=1,
            ignore_cache=args.ignore_cache,
            labeled=args.num_classes > 0,
            target_video_fps=args.target_video_fps,
        )
    msg = f'{"Train" if train else "Test"} Dataset contains {len(dataset)}, images ({args.data_path})'
    if logger is not None:
        logger.info(msg)
    else:
        print(msg)
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
def encode_video(video, vae, use_sd_vae=False):
    """Encode pixel video (b, t, c, h, w) into scaled latents with the same
    leading (b, t) layout, folding time into the batch for the VAE call.

    With `use_sd_vae` the Stable-Diffusion AutoencoderKL path is used
    (sample the posterior, apply SD's fixed 0.18215 scale); otherwise the
    VAE's own `cfg.scaling_factor` is applied to the raw encode output.
    """
    b, t, c, h, w = video.shape
    flat = einops.rearrange(video, "b t c h w-> (b t) c h w")
    if use_sd_vae:
        latents = vae.encode(flat).latent_dist.sample().mul_(0.18215)
    else:
        latents = vae.encode(flat) * vae.cfg.scaling_factor
    return einops.rearrange(latents, "(b t) c h w -> b t c h w", t=t)
def decode_video(video, vae):
    """Decode latents (b, t, c, h, w) back to pixels, undoing the encode
    scaling, in minibatches of size b to bound decoder memory use.

    AutoencoderKL instances use SD's fixed 0.18215 scale; other VAEs use
    their own `cfg.scaling_factor`. Results are detached and moved to CPU.
    """
    b = video.shape[0]
    flat = einops.rearrange(video, "b t c h w -> (b t) c h w")
    is_sd_vae = isinstance(vae, AutoencoderKL)  # invariant; hoisted out of the loop
    chunks = []
    for start in range(0, flat.shape[0], b):
        latent = flat[start:start + b]
        if is_sd_vae:
            decoded = vae.decode(latent / 0.18215).sample
        else:
            decoded = vae.decode(latent / vae.cfg.scaling_factor)
        chunks.append(decoded.detach().cpu())
    frames = torch.cat(chunks, dim=0)
    return einops.rearrange(frames, "(b t) c h w ->b t c h w", b=b)
def generate_sample(vae,
                    rectified_flow,
                    forward_fn,
                    video_length,
                    video_latent_size,
                    audio_latent_size,
                    y,
                    cfg_scale,
                    device):
    """Jointly sample a (video, audio) pair from the rectified flow.

    Draws Gaussian noise for both modalities, steps the sampler
    `video_length` times, decodes the stacked video latents through the
    VAE, and flattens the per-step audio chunks along time.
    Returns (video, audio) where audio is shaped (B, C, N, T*F).
    """
    with torch.no_grad():
        video_noise = torch.randn(video_latent_size, device=device) * rectified_flow.noise_scale
        audio_noise = torch.randn(audio_latent_size, device=device) * rectified_flow.noise_scale
        # a falsy cfg_scale omits the classifier-free-guidance kwarg entirely
        model_kwargs = dict(y=y, cfg_scale=cfg_scale) if cfg_scale else dict(y=y)
        sampler = rectified_flow.sample(
            forward_fn, video_noise, audio_noise, model_kwargs=model_kwargs, progress=True)()
        video_frames, audio_frames = [], []
        for _ in tqdm(range(video_length), desc="Generating frames"):
            v_step, a_step = next(sampler)
            video_frames.append(v_step)
            audio_frames.append(a_step)
        video = decode_video(torch.stack(video_frames, dim=1), vae)
        audio = einops.rearrange(torch.stack(audio_frames, dim=1), "B T C N F -> B C N (T F)")
        return video, audio
def generate_sample_a2v(vae,
                        rectified_flow,
                        forward_fn,
                        video_length,
                        video_latent_size,
                        audio,
                        y,
                        device,
                        cfg_scale=1,
                        scale=1):
    """Sample video conditioned on given audio latents (audio-to-video).

    Only the video side starts from noise; `audio` is passed through to the
    conditional sampler and returned flattened along time.
    NOTE(review): unlike `generate_sample` this runs without `no_grad` —
    presumably because `scale`-guided sampling may need gradients; confirm.
    """
    video_noise = torch.randn(video_latent_size, device=device) * rectified_flow.noise_scale
    model_kwargs = dict(y=y, cfg_scale=cfg_scale) if cfg_scale else dict(y=y)
    sampler = rectified_flow.sample_a2v(
        forward_fn, video_noise, audio, model_kwargs=model_kwargs, scale=scale, progress=True)()
    frames = [next(sampler) for _ in tqdm(range(video_length), desc="Generating frames")]
    video = decode_video(torch.stack(frames, dim=1), vae)
    audio = einops.rearrange(audio, "B T C N F -> B C N (T F)")
    return video, audio
def generate_sample_v2a(vae,
                        rectified_flow,
                        forward_fn,
                        video_length,
                        video,
                        audio_latent_size,
                        y,
                        device,
                        cfg_scale=1,
                        scale=1):
    """Sample audio conditioned on given video latents (video-to-audio).

    Only the audio side starts from noise; the supplied `video` latents are
    fed to the conditional sampler and also decoded through the VAE for the
    return value. Returns (decoded video, audio shaped (B, C, N, T*F)).
    """
    audio_noise = torch.randn(audio_latent_size, device=device) * rectified_flow.noise_scale
    model_kwargs = dict(y=y, cfg_scale=cfg_scale) if cfg_scale else dict(y=y)
    sampler = rectified_flow.sample_v2a(
        forward_fn, video, audio_noise, model_kwargs=model_kwargs, scale=scale, progress=True)()
    chunks = [next(sampler) for _ in tqdm(range(video_length), desc="Generating frames")]
    audio = einops.rearrange(torch.stack(chunks, dim=1), "B T C N F -> B C N (T F)")
    video = decode_video(video, vae)
    return video, audio
def dict_to_json(path, args):
    """Dump the attributes of an argparse-style namespace to `path` as JSON."""
    with open(path, 'w') as handle:
        json.dump(vars(args), handle, indent=2)
| def json_to_dict(path, args): | |
| with open(path, 'r') as f: | |
| args.__dict__ = json.load(f) | |
| return args | |
def log_args(args, logger):
    """Log every attribute of `args` under an ARGS banner, one `key=value` per line."""
    body = "".join(f'{key}={val}\n' for key, val in vars(args).items())
    logger.info(f"##### ARGS #####\n{body}")
def out2img(samples):
    """Map model outputs in roughly [-1, 1] to uint8 pixel values on the GPU."""
    scaled = 127.5 * samples + 128.0
    clamped = th.clamp(scaled, 0, 255)
    return clamped.to(dtype=th.uint8).cuda()
def get_gpu_usage():
    """Return the number of megabytes currently in use on cuda:0."""
    free_bytes, total_bytes = th.cuda.mem_get_info(th.device('cuda:0'))
    return (total_bytes - free_bytes) / 1024 ** 2
def get_wavs(norm_spec, vocoder, audio_scale, device):
    """Turn normalized spectrogram latents into waveforms via the vocoder.

    Undoes the `audio_scale` scaling and the two normalization steps
    (`denormalize`, then `denormalize_spectrogram`) before inference.
    """
    spec = norm_spec.squeeze(1) / audio_scale
    spec = denormalize(spec).to(device)
    # recover the raw spectrogram representation the vocoder expects
    raw_spec = denormalize_spectrogram(spec)
    return vocoder.inference(raw_spec)