Image-to-Video
zzwustc's picture
Upload folder using huggingface_hub
ef296aa verified
raw
history blame
6.97 kB
import os
import argparse
from PIL import Image
from glob import glob
import numpy as np
import json
import torch
import torchvision
from torch.nn import functional as F
from matplotlib import colormaps
import math
import scipy
def get_grid(height, width, shape=None, dtype="torch", device="cpu", align_corners=True, normalize=True):
    """Build a dense grid of (x, y) sampling coordinates for an H x W image.

    Args:
        height: Number of rows H.
        width: Number of columns W.
        shape: Optional leading batch shape (list, tuple or ``torch.Size``);
            the grid is broadcast (via ``expand``) over it.
        dtype: ``"torch"`` to return a tensor, ``"numpy"`` for an ndarray.
        device: Torch device on which the coordinates are created.
        align_corners: If True, the first/last coordinates sit exactly on 0 and
            1 (or 0 and W-1 / H-1 when not normalized); otherwise they sit at
            pixel centers 0.5/W .. 1 - 0.5/W (or 0.5 .. W-0.5).
        normalize: If True, coordinates are in [0, 1]; otherwise in pixels.

    Returns:
        Grid of shape ``(*shape, H, W, 2)`` whose last dim is (x, y).
    """
    H, W = height, width
    # list() so tuples / torch.Size work; concatenating a tuple with a list
    # below would otherwise raise TypeError.
    S = list(shape) if shape else []
    if align_corners:
        x = torch.linspace(0, 1, W, device=device)
        y = torch.linspace(0, 1, H, device=device)
        if not normalize:
            x = x * (W - 1)
            y = y * (H - 1)
    else:
        x = torch.linspace(0.5 / W, 1.0 - 0.5 / W, W, device=device)
        y = torch.linspace(0.5 / H, 1.0 - 0.5 / H, H, device=device)
        if not normalize:
            x = x * W
            y = y * H
    lead = [1] * len(S)
    exp = S + [H, W]
    # x varies along the last (width) axis, y along the height axis.
    x = x.view(*lead, 1, -1).expand(*exp)
    y = y.view(*lead, -1, 1).expand(*exp)
    grid = torch.stack([x, y], dim=-1)
    if dtype == "numpy":
        # .cpu() first: .numpy() raises for tensors on an accelerator device.
        grid = grid.cpu().numpy()
    return grid
def translation(frame, dx, dy, pad_value):
    """Shift a (C, H, W) image by ``(dx, dy)`` pixels with bilinear resampling.

    Output pixel (u, v) samples the input at (u - dx, v - dy), so the content
    moves by +dx along x and +dy along y. Pixels uncovered by the shift are
    filled with ``pad_value``.
    """
    _, H, W = frame.shape
    # Normalized [0, 1] grid with corners on the first/last pixel, so one pixel
    # spans 1 / (W - 1) horizontally and 1 / (H - 1) vertically.
    coords = get_grid(H, W, device=frame.device)
    coords[..., 0] = coords[..., 0] - dx / (W - 1)
    coords[..., 1] = coords[..., 1] - dy / (H - 1)
    # grid_sample's default padding fills out-of-range samples with zeros, so
    # sample (frame - pad_value) and add pad_value back afterwards.
    centered = (frame - pad_value)[None]
    sampled = torch.nn.functional.grid_sample(
        centered,
        coords[None] * 2 - 1,
        mode="bilinear",
        align_corners=True,
    )
    return sampled[0] + pad_value
def project(pos, t, time_steps, heigh, width):
    """Map 2-D pixel positions at time ``t`` through a fixed oblique projection.

    Positions are normalized to [0, 1], centered and shrunk by 4x, and paired
    with a depth coordinate ``1 - t / (T - 1)``. They are multiplied by a fixed
    3x3 matrix, of which only the first two outputs are kept, then offset and
    rescaled back to pixel units. Presumably this lays out trajectories over
    time in a pseudo-3-D view — TODO confirm with callers.

    Args:
        pos: Tensor whose last dim is (x, y) in pixels.
        t: Time index (scalar) of these positions.
        time_steps: Total number of time steps T (must be > 1; used as T - 1).
        heigh: Canvas height H in pixels (name kept for caller compatibility).
        width: Canvas width W in pixels.

    Returns:
        Floating-point tensor with ``pos``'s leading shape and last dim 2.
    """
    T, H, W = time_steps, heigh, width
    pos = torch.stack([pos[..., 0] / (W - 1), pos[..., 1] / (H - 1)], dim=-1)
    pos = (pos - 0.5) * 0.25
    depth = 1 - torch.ones_like(pos[..., :1]) * t / (T - 1)
    pos = torch.cat([pos, depth], dim=-1)
    # Build M in pos's dtype/device; a fixed float32 M made `@` fail for
    # float64 positions.
    M = torch.tensor(
        [
            [0.8, 0, 0.5],
            [-0.2, 1.0, 0.1],
            [0.0, 0.0, 0.0],
        ],
        dtype=pos.dtype,
        device=pos.device,
    )
    pos = pos @ M.t()
    pos = pos[..., :2]
    pos[..., 0] += 0.25
    pos[..., 1] += 0.45
    pos[..., 0] *= (W - 1)
    pos[..., 1] *= (H - 1)
    return pos
def draw(pos, vis, col, height, width, radius=1):
    """Splat visible colored points onto a ``height x width`` canvas.

    Each visible point is spread over nearby pixels: its four surrounding
    pixels with bilinear weights when ``radius`` is not > 1, or a disc of
    pixels otherwise. Contributions are summed per pixel, and the color is
    normalized by the summed weight.

    Args:
        pos: Point coordinates in pixels, (x, y) in the last dim; after masking
            by ``vis`` it is indexed as (K, 2).
        vis: Mask selecting which points to draw (converted with ``.bool()``).
        col: Per-point colors aligned with ``pos`` (assumed (N, 3) RGB —
            the accumulation buffer has 3 color channels + 1 weight).
        height: Canvas height H in pixels.
        width: Canvas width W in pixels.
        radius: Splat radius; values > 1 use the disc splat.

    Returns:
        ``(frame, alpha)``: ``frame`` is (H, W, 3) weight-normalized color, and
        ``alpha`` is (H, W, 1) float that is 1.0 where anything was drawn and
        0.0 elsewhere.
    """
    visible = vis.bool()
    points, colors = pos[visible], col[visible]
    if radius > 1:
        points, colors = get_radius_neighbors(points, colors, radius)
    else:
        points, colors = get_cardinal_neighbors(points, colors)

    # Drop splat targets that fall outside the canvas.
    px, py = points[:, 0], points[:, 1]
    inside = (px >= 0) & (px <= width - 1) & (py >= 0) & (py <= height - 1)
    pixels = points[inside].round().long()
    weights = colors[inside]

    # Sum (weighted rgb, weight) rows per pixel in a flat (H*W, 4) buffer.
    flat_index = (pixels[:, 1] * width + pixels[:, 0]).view(-1, 1).expand(-1, 4)
    canvas = torch.zeros(height * width, 4, device=pos.device)
    canvas.scatter_add_(0, flat_index, weights)
    canvas = canvas.view(height, width, 4)

    rgb, weight_sum = canvas[..., :3], canvas[..., 3]
    covered = weight_sum > 0
    # Normalize summed color by summed weight, in place on the rgb view.
    rgb[covered] /= weight_sum[covered][..., None]
    return rgb, covered[..., None].float()
def get_cardinal_neighbors(pos, col, eps=0.01):
    """Expand each point into its four surrounding integer pixels with bilinear weights.

    Args:
        pos: (K, 2) float coordinates, x in column 0 and y in column 1.
        col: (K, C) per-point colors.
        eps: Added to every 1-D weight so that no corner gets exactly zero
            weight (e.g. when a point lies exactly on a pixel).

    Returns:
        ``(pos, col)``: (4K, 2) corner coordinates grouped as NW, SW, NE, SE
        (each group holds all K points), and (4K, C + 1) rows ``[w * color, w]``.
    """
    x, y = pos[:, 0], pos[:, 1]
    left, top = x.floor(), y.floor()
    right, bottom = left + 1, top + 1

    # 1-D linear weights: the closer the point is to a side, the larger its weight.
    w_top = bottom - y + eps
    w_bottom = y - top + eps
    w_left = right - x + eps
    w_right = x - left + eps

    corners = (
        (left, top, w_top * w_left),          # north-west
        (left, bottom, w_bottom * w_left),    # south-west
        (right, top, w_top * w_right),        # north-east
        (right, bottom, w_bottom * w_right),  # south-east
    )
    out_pos = torch.cat([torch.stack([cx, cy], dim=-1) for cx, cy, _ in corners], dim=0)
    out_col = torch.cat(
        [torch.cat([w[:, None] * col, w[:, None]], dim=-1) for _, _, w in corners],
        dim=0,
    )
    return out_pos, out_col
def get_radius_neighbors(pos, col, radius):
    """Expand each point into the integer pixels within ``radius`` of its rounded position.

    Args:
        pos: (K, 2) float coordinates, x in column 0 and y in column 1.
        col: Per-point colors, reshaped to (K, 3).
        radius: Disc radius in pixels; offsets with squared length <=
            ``radius ** 2`` are kept.

    Returns:
        ``(pos, col)``: (K * M, 2) pixel coordinates, where M is the number of
        in-disc offsets (grouped per point), and (K * M, 4) rows
        ``[w * rgb, w]``. The weight is ``w = 1 - dist / radius + 0.01``, so it
        falls off linearly from the disc center and stays positive at the rim.
    """
    R = math.ceil(radius)
    center = torch.stack([pos[:, 0].round(), pos[:, 1].round()], dim=-1)
    # Create offsets on pos's device instead of forcing CUDA, so CPU inputs
    # work and no cross-device add happens.
    steps = torch.arange(-R, R + 1, device=pos.device)
    offsets = torch.stack(
        [steps[None, :].expand(2 * R + 1, -1), steps[:, None].expand(-1, 2 * R + 1)],
        dim=-1,
    ).view(-1, 2)
    in_radius = offsets[:, 0] ** 2 + offsets[:, 1] ** 2 <= radius ** 2
    offsets = offsets[in_radius]
    num_offsets = offsets.size(0)
    w = 1 - offsets.pow(2).sum(-1).sqrt() / radius + 0.01
    w = w[None].expand(pos.size(0), -1).reshape(-1)
    pos = (center.view(-1, 1, 2) + offsets.view(1, -1, 2)).view(-1, 2)
    col = col.view(-1, 1, 3).repeat(1, num_offsets, 1).view(-1, 3)
    col = torch.cat([col * w[:, None], w[:, None]], dim=-1)
    return pos, col
def get_rainbow_colors(size):
    """Return ``size`` RGB colors sampled evenly across matplotlib's "jet" colormap.

    Args:
        size: Number of colors.

    Returns:
        Float32 tensor of shape (size, 3). The first color is the low end of
        "jet"; for ``size >= 2`` the last color is its high end.
    """
    col_map = colormaps["jet"]
    # max(..., 1) avoids 0/0 -> NaN when size == 1; the colormap renders NaN
    # with its "bad" color (transparent black) instead of a jet color.
    col_range = np.arange(size) / max(size - 1, 1)
    col = torch.from_numpy(col_map(col_range)[..., :3]).float()
    return col.view(-1, 3)
def spline_interpolation(x, length=10):
    """Temporally upsample a (T, N, C) tensor with a cubic spline.

    Fits ``scipy.interpolate.CubicSpline`` through the T samples (one curve per
    N*C channel). The spline is evaluated at ``T * length`` evenly spaced times
    from the first to the last original step.

    Args:
        x: Tensor of shape (T, N, C); T must be at least 2 for CubicSpline.
        length: Upsampling factor; ``1`` returns ``x`` unchanged.

    Returns:
        Float32 tensor of shape (T * length, N, C) on ``x``'s device, or ``x``
        itself when ``length == 1``.
    """
    if length == 1:
        return x
    # `import scipy` alone does not guarantee the `interpolate` submodule is
    # loaded on older SciPy releases.
    import scipy.interpolate

    T, N, C = x.shape
    # detach(): .numpy() raises on tensors that require grad.
    # reshape(): tolerates non-contiguous input where view() would fail.
    samples = x.detach().reshape(T, -1).cpu().numpy()
    original_time = np.arange(T)
    spline = scipy.interpolate.CubicSpline(original_time, samples)
    new_time = np.linspace(original_time[0], original_time[-1], T * length)
    upsampled = torch.from_numpy(spline(new_time)).view(-1, N, C).float()
    # Return on the input's device rather than forcing CUDA, so CPU inputs
    # work on machines without a GPU.
    return upsampled.to(x.device)
def create_folder(path, verbose=False, exist_ok=True, safe=True):
    """Create directory ``path``, including missing parents.

    Args:
        path: Directory to create.
        verbose: Print a message when the folder is actually created.
        exist_ok: If False, an already existing ``path`` is treated as a failure.
        safe: If True, failures are reported by returning False; if False,
            they raise ``OSError`` (or a subclass such as ``FileExistsError``).

    Returns:
        True if this call created the folder. False otherwise, including when
        it already existed and ``exist_ok`` is True.
    """
    if os.path.exists(path) and not exist_ok:
        if not safe:
            raise FileExistsError(path)
        return False
    try:
        os.makedirs(path)
    except OSError:
        # Narrowed from a bare `except:` so KeyboardInterrupt / SystemExit and
        # programming errors are no longer silently swallowed.
        if not safe:
            raise
        return False
    if verbose:
        print(f"Created folder: {path}")
    return True
def write_video_to_file(video, path, channels, fps=8):
    """Encode a float video tensor (scaled by 255 into uint8) to an H.264 file.

    Args:
        video: Frames as (T, C, H, W) when ``channels == "first"``, otherwise
            (T, H, W, C); values are scaled by 255, rounded and clamped to [0, 255].
        path: Output file path; its parent folder is created if needed.
        channels: ``"first"`` if channels precede the spatial dims.
        fps: Frame rate of the written video (defaults to the previous fixed 8).

    Returns:
        The (T, H, W, C) uint8 CPU tensor that was written.
    """
    create_folder(os.path.dirname(path))
    if channels == "first":
        video = video.permute(0, 2, 3, 1)
    # Round and clamp before the uint8 cast, matching write_frame. A direct
    # float -> uint8 cast does not saturate out-of-range values (e.g. bilinear
    # or spline overshoot above 1.0).
    video = (video.detach().cpu() * 255.0).round().clamp(0, 255).to(torch.uint8)
    torchvision.io.write_video(path, video, fps, "h264", options={"pix_fmt": "yuv420p", "crf": "23"})
    return video
def write_frame(frame, path, channels="first"):
    """Save one float image tensor to ``path`` as an 8-bit image via PIL.

    Args:
        frame: Image as (C, H, W) when ``channels == "first"``, otherwise
            (H, W, C); values are scaled by 255, rounded and clipped to [0, 255].
        path: Output file path (format chosen by PIL from the extension); the
            parent folder is created if needed.
        channels: ``"first"`` if the channel dim comes first.
    """
    create_folder(os.path.dirname(path))
    # detach(): .numpy() raises on tensors that require grad.
    frame = frame.detach().cpu().numpy()
    if channels == "first":
        frame = np.transpose(frame, (1, 2, 0))
    frame = np.clip(np.round(frame * 255), 0, 255).astype(np.uint8)
    # PIL cannot build an image from an (H, W, 1) array; drop the singleton
    # channel so single-channel frames are saved as grayscale.
    if frame.ndim == 3 and frame.shape[-1] == 1:
        frame = frame[..., 0]
    Image.fromarray(frame).save(path)
def write_video_to_folder(video, path, channels, zero_padded, ext):
    """Write every frame of ``video`` to ``path/<index>.<ext>``.

    When ``zero_padded`` is true, indices are left-padded with zeros to the
    number of digits in the frame count (e.g. 10 frames -> ``00`` .. ``09``),
    so lexicographic file order matches frame order.
    """
    create_folder(path)
    num_frames = video.shape[0]
    digits = len(str(num_frames))
    for index in range(num_frames):
        stem = str(index).zfill(digits) if zero_padded else str(index)
        write_frame(video[index], os.path.join(path, f"{stem}.{ext}"), channels)
def write_video(video, path, channels="first", zero_padded=True, ext="png", dtype="torch"):
    """Write ``video`` either as an ``.mp4`` file or as a folder of image frames.

    Args:
        video: Frames as a torch tensor, or a NumPy array when
            ``dtype == "numpy"`` (converted with ``torch.from_numpy``).
        path: Destination; a path ending in ``.mp4`` produces a video file,
            and anything else is treated as a folder of per-frame images.
        channels: ``"first"`` if each frame is (C, H, W).
        zero_padded: Zero-pad frame indices in file names (folder output only).
        ext: Image file extension (folder output only).
        dtype: ``"numpy"`` if ``video`` is an ndarray.
    """
    frames = torch.from_numpy(video) if dtype == "numpy" else video
    if not path.endswith(".mp4"):
        write_video_to_folder(frames, path, channels, zero_padded, ext)
        return
    write_video_to_file(frames, path, channels)