File size: 6,110 Bytes
f075308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
from typing import List, Callable, Optional, Dict
from multiprocessing.pool import ThreadPool

from PIL import Image
import torch
from torch import Tensor
import numpy as np
import cv2
from tqdm import tqdm
from torchvision import utils
import torchvision.transforms.functional as TVF

#----------------------------------------------------------------------------

@torch.no_grad()
def generate_videos(
    G: Callable, z: Tensor, c: Tensor, ts: Tensor, motion_z: Optional[Tensor]=None,
    noise_mode='const', truncation_psi=1.0, verbose: bool=False, as_grids: bool=False, batch_size_num_frames: int=100) -> Tensor:
    """
    Generate one video per latent code, synthesizing frames in chunks along time.

    Args:
        G: generator exposing `eval()`, `z_dim`, `w_dim`, `mapping(z, c=..., truncation_psi=...)`,
            `synthesis(ws=..., c=..., t=..., motion_z=..., noise_mode=...)`,
            `synthesis.motion_encoder` and `__call__(z=..., c=..., t=..., motion_z=..., ...)`.
        z: per-video latent codes; `len(z)` must equal `len(c)` and `len(ts)`.
        c: per-video conditioning of shape [num_videos, c_dim] (c_dim may be 0).
        ts: frame timestamps of shape [num_videos, video_len].
        motion_z: optional precomputed motion codes, indexed along dim 0 by video.
            If None and `G.synthesis.motion_encoder` is set, they are computed from (c, ts).
        noise_mode: forwarded to the generator.
        truncation_psi: truncation strength. For conditional generation with psi < 1 the
            truncation center is a per-video average w estimated from 1000 random samples.
        verbose: show a tqdm progress bar over videos.
        as_grids: if True, tile all videos into one image grid per time step.
        batch_size_num_frames: maximum number of frames synthesized per generator call.

    Returns:
        CPU tensor with values clamped to [0, 1]: [num_videos, video_len, 3, h, w], or
        [video_len, 3, grid_h, grid_w] when `as_grids` is True.
    """
    assert len(ts) == len(z) == len(c), f"Wrong shape: {ts.shape}, {z.shape}, {c.shape}"
    assert ts.ndim == 2, f"Wrong shape: {ts.shape}"

    G.eval()
    use_cond_truncation = c.shape[1] > 0 and truncation_psi < 1

    if use_cond_truncation:
        # Estimate a truncation center for every video from its own conditioning.
        num_ws_to_average = 1000
        c_for_avg = c.repeat_interleave(num_ws_to_average, dim=0) # [num_videos * num_ws_to_average, c_dim]
        z_for_avg = torch.randn(c_for_avg.shape[0], G.z_dim, device=z.device) # [num_videos * num_ws_to_average, z_dim]
        w = G.mapping(z_for_avg, c=c_for_avg)[:, 0] # [num_videos * num_ws_to_average, w_dim]
        w_avg = w.view(-1, num_ws_to_average, G.w_dim).mean(dim=1) # [num_videos, w_dim]

    if motion_z is None and G.synthesis.motion_encoder is not None:
        motion_z = G.synthesis.motion_encoder(c=c, t=ts)['motion_z'] # [...any...]

    iters = range(len(z))
    iters = tqdm(iters, desc='Generating videos') if verbose else iters
    videos = []

    for video_idx in iters:
        # Everything below is independent of the frame chunk, so compute it once per video.
        curr_z = z[[video_idx]] # [1, z_dim]
        curr_c = c[[video_idx]] # [1, c_dim]
        curr_motion_z = None if motion_z is None else motion_z[[video_idx]]

        if use_cond_truncation:
            curr_w = G.mapping(curr_z, c=curr_c, truncation_psi=1) # [1, num_ws, w_dim]
            # Use only this video's center: broadcasting against the whole w_avg would
            # turn the batch of 1 into a batch of num_videos.
            curr_w_avg = w_avg[[video_idx]].unsqueeze(1) # [1, 1, w_dim]
            curr_w = truncation_psi * curr_w + (1 - truncation_psi) * curr_w_avg # [1, num_ws, w_dim]

        curr_video = []
        for curr_ts in ts[[video_idx]].split(batch_size_num_frames, dim=1):
            if use_cond_truncation:
                out = G.synthesis(
                    ws=curr_w,
                    c=curr_c,
                    t=curr_ts,
                    motion_z=curr_motion_z,
                    noise_mode=noise_mode) # [1 * curr_num_frames, 3, h, w]
            else:
                out = G(
                    z=curr_z,
                    c=curr_c,
                    t=curr_ts,
                    motion_z=curr_motion_z,
                    truncation_psi=truncation_psi,
                    noise_mode=noise_mode) # [1 * curr_num_frames, 3, h, w]

            # Map generator output from [-1, 1] to [0, 1] and free device memory early.
            curr_video.append((out * 0.5 + 0.5).clamp(0, 1).cpu())

        videos.append(torch.cat(curr_video, dim=0)) # [video_len, 3, h, w]

    videos = torch.stack(videos) # [len(z), video_len, c, h, w]

    if not as_grids:
        return videos

    frame_grids = videos.permute(1, 0, 2, 3, 4) # [video_len, len(z), c, h, w]
    nrow = int(np.sqrt(len(z)))
    return torch.stack([utils.make_grid(fs, nrow=nrow) for fs in frame_grids]) # [video_len, 3, grid_h, grid_w]

#----------------------------------------------------------------------------

def run_batchwise(fn: Callable, data_kwargs: Dict[str, Tensor], batch_size: int, **kwargs) -> Tensor:
    """
    Apply `fn` to consecutive slices of the data and concatenate the outputs along dim 0.

    Args:
        fn: called as `fn(**batch_kwargs, **kwargs)` for every slice; must return a tensor.
        data_kwargs: keyword arguments sliced along their first dimension. Entries that are
            None are dropped and therefore not passed to `fn`.
        batch_size: maximum number of items per call to `fn`.
        **kwargs: extra keyword arguments forwarded unchanged to every call.

    Returns:
        `torch.cat` of the per-slice outputs along dim 0. An empty sequence produces no
        outputs, and `torch.cat` then raises.

    Raises:
        ValueError: if `batch_size` < 1, if every entry of `data_kwargs` is None, or if
            the remaining entries do not all have the same length.
    """
    data_kwargs = {k: v for k, v in data_kwargs.items() if v is not None}

    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if not data_kwargs:
        raise ValueError("run_batchwise needs at least one non-None entry in data_kwargs")

    lengths = {k: len(v) for k, v in data_kwargs.items()}
    seq_len = next(iter(lengths.values()))
    # Mismatched lengths would silently pair up the wrong items inside each slice.
    if any(n != seq_len for n in lengths.values()):
        raise ValueError(f"All data_kwargs entries must have the same length, got {lengths}")

    result = [
        fn(**{k: v[start:start + batch_size] for k, v in data_kwargs.items()}, **kwargs)
        for start in range(0, seq_len, batch_size)
    ]

    return torch.cat(result, dim=0)

#----------------------------------------------------------------------------

def save_video_frames_as_mp4(frames: List[Tensor], fps: int, save_path: os.PathLike, verbose: bool=False):
    """
    Encode RGB frames into an MPEG-4 (`mp4v`) video file with OpenCV.

    Args:
        frames: sequence of 3-channel [3, h, w] image tensors accepted by
            `TVF.to_pil_image`. The video size is taken from the first frame.
        fps: frame rate of the written video.
        save_path: destination file path (str or path-like).
        verbose: show a tqdm progress bar over frames.

    Raises:
        ValueError: if `frames` is empty.
        RuntimeError: if OpenCV cannot open a video writer for `save_path`.
        AssertionError: if a frame does not have exactly 3 channels.
    """
    # `len` rather than truthiness so that a stacked [T, 3, h, w] tensor also works.
    if len(frames) == 0:
        raise ValueError("Cannot save an empty sequence of frames")

    frame_h, frame_w = frames[0].shape[1:]
    fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
    # Older OpenCV builds reject pathlib.Path objects, so pass a plain string.
    video = cv2.VideoWriter(os.fspath(save_path), fourcc, fps, (frame_w, frame_h))
    if not video.isOpened():
        raise RuntimeError(f"Could not open a video writer for {save_path}")

    try:
        frames = tqdm(frames, desc='Saving videos') if verbose else frames
        for frame in frames:
            assert frame.shape[0] == 3, "RGBA/grayscale images are not supported"
            frame = np.array(TVF.to_pil_image(frame))
            # OpenCV expects BGR channel order.
            video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Finalize and close the file even when a frame fails validation.
        video.release()

#----------------------------------------------------------------------------

def save_video_frames_as_frames(frames: List[Tensor], save_dir: os.PathLike, time_offset: int=0):
    """
    Write each frame as a JPEG file named by its zero-padded index.

    Args:
        frames: sequence of image tensors accepted by `TVF.to_pil_image`.
        save_dir: output directory; created if it does not exist.
        time_offset: added to each frame index to form the file name
            (e.g. `time_offset=10` writes `000010.jpg` first).
    """
    os.makedirs(save_dir, exist_ok=True)

    for i, frame in enumerate(frames, start=time_offset):
        save_path = os.path.join(save_dir, f'{i:06d}.jpg')
        # Pillow's JPEG encoder reads the `quality` option; an unknown `q` keyword is ignored.
        TVF.to_pil_image(frame).save(save_path, quality=95)

#----------------------------------------------------------------------------

def save_video_frames_as_frames_parallel(frames: List[np.ndarray], save_dir: os.PathLike, time_offset: int=0, num_processes: int=1):
    """
    Write frames as zero-padded, index-named JPEG files using a pool of worker threads.

    Args:
        frames: image arrays accepted by `PIL.Image.fromarray`.
        save_dir: output directory; created if it does not exist.
        time_offset: added to each frame index to form the file name.
        num_processes: number of worker threads; must be greater than 1.
    """
    assert num_processes > 1, "Use `save_video_frames_as_frames` if you do not plan to use num_processes > 1."
    os.makedirs(save_dir, exist_ok=True)
    save_paths = [os.path.join(save_dir, f'{i + time_offset:06d}.jpg') for i in range(len(frames))]

    # Threads are sufficient because most of the work is I/O. Exiting the `with`
    # block shuts the pool down. That is safe because `starmap` blocks until every
    # frame has been written.
    with ThreadPool(processes=num_processes) as pool:
        pool.starmap(save_jpg, zip(frames, save_paths))

#----------------------------------------------------------------------------

def save_jpg_mp_proxy(args):
    """
    Spread a single bundled argument tuple into `save_jpg`.

    Pool `map` passes each task exactly one argument, so callers pack
    `(frame, save_path)` together and this adapter unpacks it again.
    """
    positional = tuple(args)
    return save_jpg(*positional)

#----------------------------------------------------------------------------

def save_jpg(x: np.ndarray, save_path: os.PathLike):
    """
    Save an image array to `save_path` with JPEG quality 95.

    Pillow infers the output format from the file extension of `save_path`.

    Args:
        x: image array accepted by `PIL.Image.fromarray`.
        save_path: destination file path.
    """
    # Pillow's JPEG encoder reads the `quality` option; an unknown `q` keyword is ignored.
    Image.fromarray(x).save(save_path, quality=95)

#----------------------------------------------------------------------------