| import os |
| import argparse |
| from PIL import Image |
| from glob import glob |
| import numpy as np |
| import json |
| import torch |
| import torchvision |
| from torch.nn import functional as F |
| from matplotlib import colormaps |
| import math |
| import scipy |
|
|
|
|
def get_grid(height, width, shape=None, dtype="torch", device="cpu", align_corners=True, normalize=True):
    """Build a dense grid of (x, y) coordinates for an ``height x width`` image.

    Args:
        height: Number of rows (H) of the grid.
        width: Number of columns (W) of the grid.
        shape: Optional leading dimensions the grid is expanded over. Any
            sequence is accepted (list, tuple, ``torch.Size``); ``None`` or an
            empty sequence means no leading dimensions.
        dtype: ``"numpy"`` returns a NumPy array; any other value returns a
            torch tensor.
        device: Device on which the coordinates are created.
        align_corners: If True, the first/last samples lie exactly on the
            borders (0 and 1). Otherwise samples lie at pixel centers
            (``0.5 / W`` ... ``1 - 0.5 / W``).
        normalize: If True, coordinates are in [0, 1]; otherwise they are
            scaled to pixel units.

    Returns:
        Grid of shape ``(*shape, H, W, 2)``; index 0 of the last dimension is
        x (varies along the width), index 1 is y (varies along the height).
    """
    H, W = height, width
    # Copy into a list so tuple / torch.Size shapes can be concatenated below.
    S = list(shape) if shape else []

    if align_corners:
        x = torch.linspace(0, 1, W, device=device)
        y = torch.linspace(0, 1, H, device=device)
        if not normalize:
            x = x * (W - 1)
            y = y * (H - 1)
    else:
        x = torch.linspace(0.5 / W, 1.0 - 0.5 / W, W, device=device)
        y = torch.linspace(0.5 / H, 1.0 - 0.5 / H, H, device=device)
        if not normalize:
            x = x * W
            y = y * H

    # x broadcasts along rows, y along columns; leading dims are singleton
    # before being expanded to the requested shape.
    leading = [1] * len(S)
    full_shape = S + [H, W]
    x = x.view(*(leading + [1, -1])).expand(*full_shape)
    y = y.view(*(leading + [-1, 1])).expand(*full_shape)
    grid = torch.stack([x, y], dim=-1)

    if dtype == "numpy":
        # .cpu() is a no-op on CPU tensors and makes the conversion work when
        # the grid was created on another device.
        grid = grid.cpu().numpy()
    return grid
|
|
def translation(frame, dx, dy, pad_value):
    """Shift the content of a (C, H, W) frame by (dx, dy) pixels.

    The frame is resampled bilinearly on a grid displaced by ``-dx`` / ``-dy``
    (expressed in normalized, corner-aligned coordinates). ``pad_value`` is
    subtracted before sampling and added back afterwards, so whatever
    ``grid_sample`` fills in outside the image (zeros with its default
    padding mode) ends up equal to ``pad_value``.

    Args:
        frame: Tensor with three dimensions, unpacked as (C, H, W).
        dx: Horizontal shift, divided by ``W - 1`` to normalize it.
        dy: Vertical shift, divided by ``H - 1`` to normalize it.
        pad_value: Value used for regions sampled outside the frame.

    Returns:
        The shifted frame, same layout as the input.
    """
    _, rows, cols = frame.shape
    base = get_grid(rows, cols, device=frame.device)

    # Displace each coordinate channel, then map [0, 1] -> [-1, 1] as
    # expected by grid_sample.
    sample_x = base[..., 0] - (dx / (cols - 1))
    sample_y = base[..., 1] - (dy / (rows - 1))
    sample_at = torch.stack([sample_x, sample_y], dim=-1)[None] * 2 - 1

    centered = (frame - pad_value)[None]
    warped = torch.nn.functional.grid_sample(centered, sample_at, mode='bilinear', align_corners=True)
    return warped[0] + pad_value
|
|
|
|
def project(pos, t, time_steps, heigh, width):
    """Map pixel positions at time ``t`` through a fixed linear projection.

    Positions are normalized by ``(W - 1, H - 1)``, centered, shrunk by a
    factor of 4 and augmented with a third coordinate ``1 - t / (T - 1)``.
    The resulting 3-vectors are multiplied by a constant 3x3 matrix, the first
    two output components are offset by (0.25, 0.45) and scaled back to pixel
    units.

    Args:
        pos: Tensor whose last dimension holds (x, y) in pixel units.
        t: Time index of these positions.
        time_steps: Total number of time steps T.
        heigh: Image height H (parameter name kept for compatibility).
        width: Image width W.

    Returns:
        Tensor with the same leading dimensions as ``pos`` and a last
        dimension of size 2 holding the projected (x, y) in pixel units.
    """
    T, H, W = time_steps, heigh, width

    # Normalize, center and shrink the spatial coordinates.
    plane = torch.stack([pos[..., 0] / (W - 1), pos[..., 1] / (H - 1)], dim=-1)
    plane = (plane - 0.5) * 0.25

    # Third coordinate: 1 at t == 0, decreasing to 0 at t == T - 1.
    depth = 1 - torch.ones_like(plane[..., :1]) * t / (T - 1)
    points = torch.cat([plane, depth], dim=-1)

    matrix = torch.tensor([
        [0.8, 0, 0.5],
        [-0.2, 1.0, 0.1],
        [0.0, 0.0, 0.0],
    ]).to(pos.device)
    projected = points @ matrix.t()

    # Offset in normalized units, then convert back to pixels.
    out_x = (projected[..., 0] + 0.25) * (W - 1)
    out_y = (projected[..., 1] + 0.45) * (H - 1)
    return torch.stack([out_x, out_y], dim=-1)
|
|
def draw(pos, vis, col, height, width, radius=1):
    """Splat visible colored points onto an ``height x width`` canvas.

    Each visible point is spread over neighboring pixels (a disc when
    ``radius > 1``, otherwise its four surrounding pixels). Weighted colors
    and weights are accumulated per pixel, and the color of every touched
    pixel is divided by its accumulated weight.

    Args:
        pos: Point positions, indexed as ``pos[:, 0]`` = x and
            ``pos[:, 1]`` = y after masking.
        vis: Visibility flags; converted with ``.bool()`` and used as a mask.
        col: Per-point colors (three channels are accumulated alongside one
            weight channel).
        height: Canvas height H.
        width: Canvas width W.
        radius: Splat radius; values above 1 select the disc neighborhood.

    Returns:
        Tuple ``(frame, alpha)``: ``frame`` is (H, W, 3) with the normalized
        colors, ``alpha`` is (H, W, 1) and equals 1 where any weight landed.
    """
    visible = vis.bool()
    points, colors = pos[visible], col[visible]

    if radius > 1:
        points, colors = get_radius_neighbors(points, colors, radius)
    else:
        points, colors = get_cardinal_neighbors(points, colors)

    # Discard contributions that fall outside the canvas.
    xs, ys = points[:, 0], points[:, 1]
    keep = (xs >= 0) & (xs <= width - 1) & (ys >= 0) & (ys <= height - 1)
    pixels = points[keep].round().long()
    contributions = colors[keep]

    # Accumulate (weighted rgb, weight) rows into a flattened canvas.
    flat_index = (pixels[:, 1] * width + pixels[:, 0]).view(-1, 1).expand(-1, 4)
    canvas = torch.zeros(height * width, 4, device=pos.device)
    canvas.scatter_add_(0, flat_index, contributions)
    canvas = canvas.view(height, width, 4)

    rgb, weight = canvas[..., :3], canvas[..., 3]
    covered = weight > 0
    # Normalize only where something was drawn to avoid dividing by zero.
    rgb[covered] = rgb[covered] / weight[covered][..., None]
    return rgb, covered[..., None].float()
|
|
def get_cardinal_neighbors(pos, col, eps=0.01):
    """Distribute each point over its four surrounding integer pixels.

    Weights are bilinear (product of the per-axis distances to the opposite
    pixel), with ``eps`` added to each per-axis weight before multiplying.

    Args:
        pos: (N, 2) tensor of (x, y) positions.
        col: (N, K) tensor of per-point colors.
        eps: Constant added to every per-axis weight.

    Returns:
        Tuple ``(pos, col)``:
        ``pos`` is (4N, 2) with the corner coordinates, grouped as all
        north-west, then south-west, north-east and south-east corners;
        ``col`` is (4N, K + 1) holding the weighted color followed by the
        weight itself.
    """
    x, y = pos[:, 0], pos[:, 1]
    left, top = x.floor(), y.floor()
    right, bottom = left + 1, top + 1

    # Per-axis weights: the closer a point is to a side, the larger its weight.
    w_top = top + 1 - y + eps
    w_bottom = y - top + eps
    w_left = left + 1 - x + eps
    w_right = x - left + eps

    corners = (
        (left, top, w_top * w_left),          # north-west
        (left, bottom, w_bottom * w_left),    # south-west
        (right, top, w_top * w_right),        # north-east
        (right, bottom, w_bottom * w_right),  # south-east
    )

    corner_pos, corner_col = [], []
    for corner_x, corner_y, weight in corners:
        weight = weight[:, None]
        corner_pos.append(torch.stack([corner_x, corner_y], dim=-1))
        corner_col.append(torch.cat([weight * col, weight], dim=-1))

    return torch.cat(corner_pos, dim=0), torch.cat(corner_col, dim=0)
|
|
|
|
def get_radius_neighbors(pos, col, radius):
    """Distribute each point over all integer pixels within ``radius``.

    Every point is snapped to its nearest pixel and replicated over the
    integer offsets (dx, dy) with ``dx**2 + dy**2 <= radius**2``. Each copy
    gets the weight ``1 - distance / radius + 0.01``, so the center counts
    most and the rim keeps a small positive weight.

    The offsets are created on ``pos.device``, so the function works on CPU
    as well as on GPU inputs.

    Args:
        pos: (N, 2) tensor of (x, y) positions.
        col: (N, 3) tensor of per-point colors.
        radius: Disc radius in pixels.

    Returns:
        Tuple ``(pos, col)`` where, with K the number of offsets in the disc,
        ``pos`` is (N * K, 2) pixel coordinates and ``col`` is (N * K, 4)
        holding the weighted color followed by the weight.
    """
    R = math.ceil(radius)
    side = 2 * R + 1
    center = torch.stack([pos[:, 0].round(), pos[:, 1].round()], dim=-1)

    # All (dx, dy) offsets of the bounding square; dx varies fastest.
    steps = torch.arange(-R, R + 1, device=pos.device)
    offsets = torch.stack(
        [steps[None, :].expand(side, -1), steps[:, None].expand(-1, side)],
        dim=-1,
    ).reshape(-1, 2)

    # Keep only the offsets that lie inside the disc.
    in_radius = offsets[:, 0] ** 2 + offsets[:, 1] ** 2 <= radius ** 2
    offsets = offsets[in_radius]
    num_offsets = offsets.size(0)

    # Linear falloff with distance; +0.01 keeps rim weights strictly positive.
    w = 1 - offsets.pow(2).sum(-1).sqrt() / radius + 0.01
    w = w[None].expand(pos.size(0), -1).reshape(-1)

    pos = (center.view(-1, 1, 2) + offsets.view(1, -1, 2)).reshape(-1, 2)
    col = col.reshape(-1, 1, 3).repeat(1, num_offsets, 1).reshape(-1, 3)
    col = torch.cat([col * w[:, None], w[:, None]], dim=-1)
    return pos, col
|
|
|
|
def get_rainbow_colors(size):
    """Return ``size`` RGB colors sampled evenly along the "jet" colormap.

    Args:
        size: Number of colors to generate.

    Returns:
        Float tensor of shape (size, 3). The i-th color is the colormap
        evaluated at ``i / (size - 1)``; for ``size == 1`` the single color is
        the colormap evaluated at 0.
    """
    col_map = colormaps["jet"]
    # Guard the denominator: with size == 1 the plain formula is 0 / 0 = NaN,
    # which the colormap would treat as an invalid value.
    col_range = np.arange(size) / max(size - 1, 1)
    # The colormap returns RGBA; drop the alpha channel.
    rgb = col_map(col_range)[..., :3]
    return torch.from_numpy(rgb).float().reshape(-1, 3)
|
|
|
|
def spline_interpolation(x, length=10):
    """Upsample a sequence along its first (time) dimension with a cubic spline.

    Args:
        x: Tensor with three dimensions, unpacked as (T, N, C). The original
            samples are placed at integer times ``0 .. T - 1``.
        length: Upsampling factor. With ``length == 1`` the input is returned
            unchanged.

    Returns:
        For ``length != 1``, a float32 tensor of shape (T * length, N, C)
        sampled uniformly between the first and last original time step, on
        the same device as ``x`` (the result is no longer forced onto CUDA,
        so CPU inputs work on machines without a GPU).
    """
    if length == 1:
        return x

    # Imported here so the submodule is guaranteed to be loaded; a bare
    # ``import scipy`` does not promise that on every SciPy version.
    from scipy.interpolate import CubicSpline

    T, N, C = x.shape
    # reshape (instead of view) also accepts non-contiguous inputs.
    samples = x.detach().reshape(T, -1).cpu().numpy()
    original_time = np.arange(T)
    spline = CubicSpline(original_time, samples)
    new_time = np.linspace(original_time[0], original_time[-1], T * length)
    dense = torch.from_numpy(spline(new_time)).view(-1, N, C).float()
    return dense.to(x.device)
|
|
def create_folder(path, verbose=False, exist_ok=True, safe=True):
    """Create ``path`` (including missing parents) and report whether it was created.

    Args:
        path: Folder to create.
        verbose: If True, print a message when the folder is created.
        exist_ok: If True, an already existing folder is not an error.
        safe: If True, failures are reported by returning False; if False,
            they raise ``OSError``.

    Returns:
        True if the folder was created by this call. False if it already
        existed (with ``exist_ok=True``) or, when ``safe`` is True, if it
        could not be created.

    Raises:
        OSError: Only when ``safe`` is False and the folder could not be
            created, or it already exists while ``exist_ok`` is False.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        # An existing directory is fine when the caller allows it; nothing
        # was created, so report False without raising.
        if exist_ok and os.path.isdir(path):
            return False
        if not safe:
            raise
        return False
    except (OSError, ValueError) as err:
        # Covers permission problems, empty or otherwise invalid paths, ...
        if not safe:
            raise OSError(f"Could not create folder: {path}") from err
        return False

    if verbose:
        print(f"Created folder: {path}")
    return True
|
|
|
|
def write_video_to_file(video, path, channels, fps=8):
    """Encode a video tensor as an H.264 file at ``path``.

    Args:
        video: 4-D tensor of frames; values are multiplied by 255, so they
            are presumably expected in [0, 1].
        path: Destination file; its parent folder is created if needed.
        channels: ``"first"`` if the channel dimension is at index 1, in
            which case it is moved to the last position before encoding.
        fps: Frame rate of the written video (defaults to the previously
            hard-coded value of 8).

    Returns:
        The uint8, channels-last tensor that was handed to the encoder.
    """
    create_folder(os.path.dirname(path))
    if channels == "first":
        video = video.permute(0, 2, 3, 1)
    # Round and clamp before the integer cast, as write_frame does: a plain
    # cast truncates and lets out-of-range values wrap around.
    video = (video.cpu() * 255.).round().clamp(0, 255).to(torch.uint8)
    torchvision.io.write_video(path, video, fps, "h264", options={"pix_fmt": "yuv420p", "crf": "23"})
    return video
|
|
|
|
def write_frame(frame, path, channels="first"):
    """Save a single frame tensor as an image file.

    Args:
        frame: Tensor holding one image; values are scaled by 255, rounded
            and clipped to [0, 255] before being stored as uint8.
        path: Destination image path; its parent folder is created if needed.
        channels: ``"first"`` if the frame is laid out channels-first, in
            which case the first axis is moved to the end.
    """
    create_folder(os.path.dirname(path))

    pixels = frame.cpu().numpy()
    if channels == "first":
        pixels = pixels.transpose(1, 2, 0)

    scaled = np.round(pixels * 255)
    quantized = np.clip(scaled, 0, 255).astype(np.uint8)
    Image.fromarray(quantized).save(path)
|
|
|
|
def write_video_to_folder(video, path, channels, zero_padded, ext):
    """Write every frame of a video as an individual image inside ``path``.

    Files are named ``<index>.<ext>``. With ``zero_padded`` the index is
    left-padded with zeros to the number of digits of the frame count.

    Args:
        video: Sequence of frames indexed along its first dimension.
        path: Destination folder (created if needed).
        channels: Channel layout forwarded to ``write_frame``.
        zero_padded: Whether to zero-pad the frame index in file names.
        ext: File extension, without the leading dot.
    """
    create_folder(path)
    num_frames = video.shape[0]
    # zfill(0) leaves the index untouched, which covers the unpadded case.
    digits = len(str(num_frames)) if zero_padded else 0
    for index in range(num_frames):
        file_name = f"{str(index).zfill(digits)}.{ext}"
        write_frame(video[index], os.path.join(path, file_name), channels)
| |
| |
|
|
def write_video(video, path, channels="first", zero_padded=True, ext="png", dtype="torch"):
    """Write a video either as an ``.mp4`` file or as a folder of images.

    Args:
        video: The frames, as a torch tensor or (with ``dtype="numpy"``) a
            NumPy array that is converted to a tensor first.
        path: Destination. A path ending in ``.mp4`` produces a video file;
            anything else is treated as a folder of per-frame images.
        channels: Channel layout forwarded to the writer.
        zero_padded: Folder output only: zero-pad frame indices in file names.
        ext: Folder output only: image file extension.
        dtype: ``"numpy"`` if ``video`` is a NumPy array.
    """
    frames = torch.from_numpy(video) if dtype == "numpy" else video

    if not path.endswith(".mp4"):
        write_video_to_folder(frames, path, channels, zero_padded, ext)
        return
    write_video_to_file(frames, path, channels)
|
|