FashionFlow / src /training /logging.py
tasin
init
f075308
import os
from typing import List, Callable, Optional, Dict
from multiprocessing.pool import ThreadPool
from PIL import Image
import torch
from torch import Tensor
import numpy as np
import cv2
from tqdm import tqdm
from torchvision import utils
import torchvision.transforms.functional as TVF
#----------------------------------------------------------------------------
@torch.no_grad()
def generate_videos(
G: Callable, z: Tensor, c: Tensor, ts: Tensor, motion_z: Optional[Tensor]=None,
noise_mode='const', truncation_psi=1.0, verbose: bool=False, as_grids: bool=False, batch_size_num_frames: int=100) -> Tensor:
assert len(ts) == len(z) == len(c), f"Wrong shape: {ts.shape}, {z.shape}, {c.shape}"
assert ts.ndim == 2, f"Wrong shape: {ts.shape}"
G.eval()
videos = []
if c.shape[1] > 0 and truncation_psi < 1:
num_ws_to_average = 1000
c_for_avg = c.repeat_interleave(num_ws_to_average, dim=0) # [num_classes * num_ws_to_average, num_classes]
z_for_avg = torch.randn(c_for_avg.shape[0], G.z_dim, device=z.device) # [num_classes * num_ws_to_average, z_dim]
w = G.mapping(z_for_avg, c=c_for_avg)[:, 0] # [num_classes * num_ws_to_average, w_dim]
w_avg = w.view(-1, num_ws_to_average, G.w_dim).mean(dim=1) # [num_classes, w_dim]
iters = range(len(z))
iters = tqdm(iters, desc='Generating videos') if verbose else iters
if motion_z is None and not G.synthesis.motion_encoder is None:
motion_z = G.synthesis.motion_encoder(c=c, t=ts)['motion_z'] # [...any...]
for video_idx in iters:
curr_video = []
for curr_ts in ts[[video_idx]].split(batch_size_num_frames, dim=1):
curr_z = z[[video_idx]] # [1, z_dim]
curr_c = c[[video_idx]] # [1, c_dim]
curr_motion_z = motion_z[[video_idx]]
if curr_c.shape[1] > 0 and truncation_psi < 1:
curr_w = G.mapping(curr_z, c=curr_c, truncation_psi=1) # [1, num_ws, w_dim]
curr_w = truncation_psi * curr_w + (1 - truncation_psi) * w_avg.unsqueeze(1) # [1, num_ws, w_dim]
out = G.synthesis(
ws=curr_w,
c=curr_c,
t=curr_ts,
motion_z=curr_motion_z,
noise_mode=noise_mode) # [1 * curr_num_frames, 3, h, w]
else:
out = G(
z=curr_z,
c=curr_c,
t=curr_ts,
motion_z=curr_motion_z,
truncation_psi=truncation_psi,
noise_mode=noise_mode) # [1 * curr_num_frames, 3, h, w]
out = (out * 0.5 + 0.5).clamp(0, 1).cpu() # [1 * curr_num_frames, 3, h, w]
curr_video.append(out)
videos.append(torch.cat(curr_video, dim=0))
videos = torch.stack(videos) # [len(z), video_len, c, h, w]
if as_grids:
frame_grids = videos.permute(1, 0, 2, 3, 4) # [video_len, len(z), c, h, w]
frame_grids = [utils.make_grid(fs, nrow=int(np.sqrt(len(z)))) for fs in frame_grids] # [video_len, 3, grid_h, grid_w]
return torch.stack(frame_grids)
else:
return videos
#----------------------------------------------------------------------------
def run_batchwise(fn: Callable, data_kwargs: Dict[str, Tensor], batch_size: int, **kwargs) -> Tensor:
data_kwargs = {k: v for k, v in data_kwargs.items() if not v is None}
seq_len = len(data_kwargs[list(data_kwargs.keys())[0]])
result = []
for i in range((seq_len + batch_size - 1) // batch_size):
curr_data_kwargs = {k: d[i * batch_size: (i+1) * batch_size] for k, d in data_kwargs.items()}
result.append(fn(**curr_data_kwargs, **kwargs))
return torch.cat(result, dim=0)
#----------------------------------------------------------------------------
def save_video_frames_as_mp4(frames: List[Tensor], fps: int, save_path: os.PathLike, verbose: bool=False):
# Load data
frame_h, frame_w = frames[0].shape[1:]
fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
video = cv2.VideoWriter(save_path, fourcc, fps, (frame_w, frame_h))
frames = tqdm(frames, desc='Saving videos') if verbose else frames
for frame in frames:
assert frame.shape[0] == 3, "RGBA/grayscale images are not supported"
frame = np.array(TVF.to_pil_image(frame))
video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
# Uncomment this line to release the memory.
# It didn't work for me on centos and complained about installing additional libraries (which requires root access)
# cv2.destroyAllWindows()
video.release()
#----------------------------------------------------------------------------
def save_video_frames_as_frames(frames: List[Tensor], save_dir: os.PathLike, time_offset: int=0):
os.makedirs(save_dir, exist_ok=True)
for i, frame in enumerate(frames):
save_path = os.path.join(save_dir, f'{i + time_offset:06d}.jpg')
TVF.to_pil_image(frame).save(save_path, q=95)
#----------------------------------------------------------------------------
def save_video_frames_as_frames_parallel(frames: List[np.ndarray], save_dir: os.PathLike, time_offset: int=0, num_processes: int=1):
assert num_processes > 1, "Use `save_video_frames_as_frames` if you do not plan to use num_processes > 1."
os.makedirs(save_dir, exist_ok=True)
# We are fine with the ThreadPool instead of Pool since most of the work is I/O
pool = ThreadPool(processes=num_processes)
save_paths = [os.path.join(save_dir, f'{i + time_offset:06d}.jpg') for i in range(len(frames))]
pool.map(save_jpg_mp_proxy, [(f, p) for f, p in zip(frames, save_paths)])
#----------------------------------------------------------------------------
def save_jpg_mp_proxy(args):
return save_jpg(*args)
#----------------------------------------------------------------------------
def save_jpg(x: np.ndarray, save_path: os.PathLike):
Image.fromarray(x).save(save_path, q=95)
#----------------------------------------------------------------------------