Spaces:
Sleeping
Sleeping
| import os | |
| from typing import List, Callable, Optional, Dict | |
| from multiprocessing.pool import ThreadPool | |
| from PIL import Image | |
| import torch | |
| from torch import Tensor | |
| import numpy as np | |
| import cv2 | |
| from tqdm import tqdm | |
| from torchvision import utils | |
| import torchvision.transforms.functional as TVF | |
| #---------------------------------------------------------------------------- | |
def generate_videos(
    G: Callable, z: Tensor, c: Tensor, ts: Tensor, motion_z: Optional[Tensor]=None,
    noise_mode='const', truncation_psi=1.0, verbose: bool=False, as_grids: bool=False, batch_size_num_frames: int=100) -> Tensor:
    """
    Generate one video per latent code with the video generator `G`.

    Frames are synthesized in chunks of at most `batch_size_num_frames` timesteps,
    rescaled with `x * 0.5 + 0.5` (i.e. assuming G outputs roughly [-1, 1]),
    clamped to [0, 1] and moved to the CPU.

    Side effect: calls `G.eval()` on the caller's module.
    NOTE(review): autograd is not disabled here; callers presumably wrap this in
    `torch.no_grad()` — confirm.

    Args:
        G: generator. Must expose `eval()`, `mapping(z, c=..., truncation_psi=...)`,
            `synthesis(ws=..., c=..., t=..., motion_z=..., noise_mode=...)`,
            `synthesis.motion_encoder`, `z_dim`, `w_dim`, and be callable as
            `G(z=..., c=..., t=..., motion_z=..., truncation_psi=..., noise_mode=...)`.
        z: latent codes, one row per video (`[num_videos, z_dim]`).
        c: conditioning, one row per video (`[num_videos, c_dim]`; `c_dim` may be 0).
        ts: timesteps to render, `[num_videos, video_len]`.
        motion_z: optional per-video motion codes. If None and
            `G.synthesis.motion_encoder` is set, they are computed by that encoder;
            otherwise None is forwarded to `G`.
        noise_mode: forwarded to the synthesis network.
        truncation_psi: truncation strength. For conditional generators
            (`c_dim > 0`) with `truncation_psi < 1`, the truncation center is this
            video's own average of 1000 mapped random latents (the first `w` of each),
            computed under that video's condition. Otherwise truncation is
            delegated to `G`.
        verbose: show a tqdm progress bar over videos.
        as_grids: if True, return one image grid per timestep instead of per-video frames.
        batch_size_num_frames: maximum number of frames synthesized per forward pass.

    Returns:
        `[num_videos, video_len, 3, h, w]` tensor in [0, 1], or, when `as_grids` is True,
        `[video_len, 3, grid_h, grid_w]` grids with `int(sqrt(num_videos))` videos per row.
    """
    assert len(ts) == len(z) == len(c), f"Wrong shape: {ts.shape}, {z.shape}, {c.shape}"
    assert ts.ndim == 2, f"Wrong shape: {ts.shape}"

    G.eval()
    videos = []
    use_cond_truncation = c.shape[1] > 0 and truncation_psi < 1

    if use_cond_truncation:
        # Per-video truncation center: mean of the first w of 1000 random latents
        # mapped under that video's condition.
        num_ws_to_average = 1000
        c_for_avg = c.repeat_interleave(num_ws_to_average, dim=0) # [num_videos * num_ws_to_average, c_dim]
        z_for_avg = torch.randn(c_for_avg.shape[0], G.z_dim, device=z.device) # [num_videos * num_ws_to_average, z_dim]
        w = G.mapping(z_for_avg, c=c_for_avg)[:, 0] # [num_videos * num_ws_to_average, w_dim]
        w_avg = w.view(-1, num_ws_to_average, G.w_dim).mean(dim=1) # [num_videos, w_dim]

    iters = range(len(z))
    iters = tqdm(iters, desc='Generating videos') if verbose else iters

    if motion_z is None and G.synthesis.motion_encoder is not None:
        motion_z = G.synthesis.motion_encoder(c=c, t=ts)['motion_z'] # structure depends on the motion encoder

    for video_idx in iters:
        # These inputs do not depend on the frame chunk, so compute them once per video.
        curr_z = z[[video_idx]] # [1, z_dim]
        curr_c = c[[video_idx]] # [1, c_dim]
        # Forward None instead of failing on `None[[idx]]` when no motion codes exist.
        curr_motion_z = None if motion_z is None else motion_z[[video_idx]]

        if use_cond_truncation:
            # The mapping call does not take `t`, so one call per video suffices.
            curr_w = G.mapping(curr_z, c=curr_c, truncation_psi=1) # [1, num_ws, w_dim]
            # Use only this video's center: the full `w_avg` ([num_videos, 1, w_dim]
            # after unsqueeze) would broadcast the batch up to num_videos rows.
            curr_w_avg = w_avg[[video_idx]].unsqueeze(1) # [1, 1, w_dim]
            curr_w = truncation_psi * curr_w + (1 - truncation_psi) * curr_w_avg # [1, num_ws, w_dim]

        curr_video = []
        for curr_ts in ts[[video_idx]].split(batch_size_num_frames, dim=1):
            if use_cond_truncation:
                out = G.synthesis(
                    ws=curr_w,
                    c=curr_c,
                    t=curr_ts,
                    motion_z=curr_motion_z,
                    noise_mode=noise_mode) # [1 * curr_num_frames, 3, h, w]
            else:
                out = G(
                    z=curr_z,
                    c=curr_c,
                    t=curr_ts,
                    motion_z=curr_motion_z,
                    truncation_psi=truncation_psi,
                    noise_mode=noise_mode) # [1 * curr_num_frames, 3, h, w]
            out = (out * 0.5 + 0.5).clamp(0, 1).cpu() # [1 * curr_num_frames, 3, h, w]
            curr_video.append(out)

        videos.append(torch.cat(curr_video, dim=0))

    videos = torch.stack(videos) # [num_videos, video_len, 3, h, w]

    if not as_grids:
        return videos

    frame_grids = videos.permute(1, 0, 2, 3, 4) # [video_len, num_videos, 3, h, w]
    frame_grids = [utils.make_grid(fs, nrow=int(np.sqrt(len(z)))) for fs in frame_grids] # [video_len, 3, grid_h, grid_w]
    return torch.stack(frame_grids)
| #---------------------------------------------------------------------------- | |
def run_batchwise(fn: Callable, data_kwargs: Dict[str, Tensor], batch_size: int, **kwargs) -> Tensor:
    """
    Apply `fn` to `data_kwargs` in consecutive dim-0 chunks and concatenate the results.

    Args:
        fn: called as `fn(**chunk, **kwargs)` once per chunk; its outputs must be
            concatenable along dim 0.
        data_kwargs: tensors sliced along dim 0 into chunks of `batch_size` rows.
            Entries whose value is None are dropped and not forwarded to `fn`.
            The row count is taken from the first remaining entry.
        batch_size: maximum number of rows per call to `fn`; must be positive.
        **kwargs: forwarded unchanged to every call of `fn`.

    Returns:
        `torch.cat` of all per-chunk outputs along dim 0.

    Raises:
        ValueError: if `data_kwargs` holds no non-None value, or `batch_size` is not positive.
    """
    data_kwargs = {k: v for k, v in data_kwargs.items() if v is not None}
    if not data_kwargs:
        raise ValueError("`data_kwargs` must contain at least one non-None tensor")
    if batch_size <= 0:
        raise ValueError(f"`batch_size` must be positive, got {batch_size}")

    seq_len = len(next(iter(data_kwargs.values())))
    result = [
        fn(**{k: d[start:start + batch_size] for k, d in data_kwargs.items()}, **kwargs)
        for start in range(0, seq_len, batch_size)
    ]
    return torch.cat(result, dim=0)
| #---------------------------------------------------------------------------- | |
def save_video_frames_as_mp4(frames: List[Tensor], fps: int, save_path: os.PathLike, verbose: bool=False):
    """
    Encode a sequence of RGB frames into a video file with OpenCV (fourcc 'mp4v').

    Args:
        frames: non-empty sequence of `[3, h, w]` frames, each converted with
            `TVF.to_pil_image`. The writer's frame size is taken from `frames[0]`,
            so presumably all frames must share that size.
        fps: frame rate written to the container.
        save_path: output file path.
        verbose: show a tqdm progress bar.

    Raises:
        RuntimeError: if OpenCV cannot open a writer for `save_path`.
        AssertionError: if a frame does not have exactly 3 channels.
    """
    frame_h, frame_w = frames[0].shape[1:]
    fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
    # Older OpenCV bindings only accept `str` paths; os.fspath also handles pathlib.Path.
    video = cv2.VideoWriter(os.fspath(save_path), fourcc, fps, (frame_w, frame_h))
    try:
        # cv2.VideoWriter does not raise on failure; without this check nothing would be written.
        if not video.isOpened():
            raise RuntimeError(f"Could not open a video writer for {save_path}")
        for frame in (tqdm(frames, desc='Saving videos') if verbose else frames):
            assert frame.shape[0] == 3, "RGBA/grayscale images are not supported"
            frame_rgb = np.array(TVF.to_pil_image(frame))
            # OpenCV expects BGR channel order.
            video.write(cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR))
    finally:
        # Release even if a frame fails midway, so the file handle is not leaked.
        # NOTE: `cv2.destroyAllWindows()` is intentionally not called: the original author
        # reported it required extra system libraries (root access) on CentOS.
        video.release()
| #---------------------------------------------------------------------------- | |
def save_video_frames_as_frames(frames: List[Tensor], save_dir: os.PathLike, time_offset: int=0):
    """
    Save each frame as `{index + time_offset:06d}.jpg` inside `save_dir`.

    Args:
        frames: frames accepted by `TVF.to_pil_image` (e.g. `[3, h, w]` tensors).
        save_dir: output directory; created if it does not exist.
        time_offset: added to the frame index in each file name.
    """
    os.makedirs(save_dir, exist_ok=True)
    for i, frame in enumerate(frames):
        save_path = os.path.join(save_dir, f'{i + time_offset:06d}.jpg')
        # Pillow's JPEG encoder reads `quality`; an unknown keyword such as `q` is
        # silently ignored, which would fall back to the default quality (75).
        TVF.to_pil_image(frame).save(save_path, quality=95)
| #---------------------------------------------------------------------------- | |
def save_video_frames_as_frames_parallel(frames: List[np.ndarray], save_dir: os.PathLike, time_offset: int=0, num_processes: int=1):
    """
    Save numpy frames as `{index + time_offset:06d}.jpg` in `save_dir` using a thread pool.

    Args:
        frames: image arrays, each passed to `save_jpg` (i.e. `PIL.Image.fromarray`).
        save_dir: output directory; created if it does not exist.
        time_offset: added to the frame index in each file name.
        num_processes: number of worker threads; must be > 1 (otherwise use
            `save_video_frames_as_frames`).
    """
    assert num_processes > 1, "Use `save_video_frames_as_frames` if you do not plan to use num_processes > 1."
    os.makedirs(save_dir, exist_ok=True)
    save_paths = [os.path.join(save_dir, f'{i + time_offset:06d}.jpg') for i in range(len(frames))]
    # A ThreadPool (instead of a process Pool) is fine since most of the work is I/O.
    # `map` blocks until every frame is saved; the `with` block then shuts the
    # workers down so no threads are leaked.
    with ThreadPool(processes=num_processes) as pool:
        pool.map(save_jpg_mp_proxy, list(zip(frames, save_paths)))
| #---------------------------------------------------------------------------- | |
def save_jpg_mp_proxy(args):
    """
    Adapter for single-argument `Pool.map`: unpack one argument tuple into `save_jpg`.

    `args` is expected to be `(x, save_path)`; `save_jpg`'s return value is passed through.
    """
    positional = tuple(args)
    return save_jpg(*positional)
| #---------------------------------------------------------------------------- | |
def save_jpg(x: np.ndarray, save_path: os.PathLike):
    """
    Write image array `x` to `save_path` with JPEG quality 95.

    The file format is inferred by Pillow from the extension of `save_path`;
    callers in this module use `.jpg`.

    Args:
        x: array accepted by `PIL.Image.fromarray` (e.g. a `uint8` `[h, w, 3]` array).
        save_path: destination file path.
    """
    # Pillow's JPEG encoder reads `quality`; an unknown keyword such as `q` is
    # silently ignored, which would fall back to the default quality (75).
    Image.fromarray(x).save(save_path, quality=95)
| #---------------------------------------------------------------------------- | |