# phi-noise / video_processing_utils.py
# Hugging Face Hub file-page residue (not part of the module):
#   uploaded by ofirab ("Upload folder using huggingface_hub"),
#   commit bee9822 (verified), file size 12.9 kB.
import sys
import os
import torch
import cv2
import numpy as np
import imageio
from torchvision.transforms import Grayscale, functional as TF
# Clip lengths (in frames) for the 5B and 14B presets. The preprocessing
# functions in this file spell out the same numbers (121 / 81) themselves
# rather than referencing these names.
NUM_FRAMES_5B = 121
NUM_FRAMES_14B = 81
# Temporal sub-sampling step applied to the motion-mask frames in load_video
# (every STRIDE-th mask is kept).
STRIDE = 8
def load_video(path, target_size=(1280, 704),
               ret_motion_video=False,
               frame_limit=None,
               stride=(4, 16, 16)):
    """Load a video as a normalized tensor, optionally with a binary motion mask.

    Each frame is converted BGR -> RGB, resized to ``target_size`` and scaled
    from [0, 255] to [-1, 1].

    When ``ret_motion_video`` is true, a motion mask is built per frame from
    the absolute grayscale difference to the previous frame (the first frame
    is compared with itself), thresholded at 10 to {0, 255} and resized with
    nearest-neighbour interpolation to
    ``(target_size[0] // stride[-2], target_size[1] // stride[-1])``.
    Every ``STRIDE``-th mask is kept.

    Args:
        path: Video file path handed to ``cv2.VideoCapture``.
        target_size: ``(width, height)`` passed to ``cv2.resize``.
        ret_motion_video: Also compute and return the motion mask.
        frame_limit: If not None, keep only the first ``frame_limit`` entries
            along the frame axis of the video, and of the already
            sub-sampled motion mask.
        stride: Only the last two entries are used, as spatial down-scaling
            factors for the motion mask.

    Returns:
        A float32 tensor shaped [C, F, H, W] in [-1, 1]; or, when
        ``ret_motion_video`` is true, a tuple ``(video, motion_mask)`` where
        ``motion_mask`` is shaped [1, F', H', W'] with values 0 / 255.

    Raises:
        ValueError: If no frame could be read from ``path``.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    motion_frames = []
    prev_frame_gs = None
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(cv2.resize(frame_rgb, target_size))
            if ret_motion_video:
                # Grayscale is only needed for the motion mask.
                frame_gs = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                if prev_frame_gs is None:
                    # First frame: diff against itself -> all-zero mask.
                    prev_frame_gs = frame_gs
                motion_frame = cv2.absdiff(frame_gs, prev_frame_gs)
                # Binary mask: 255 where the change exceeds the threshold.
                motion_frame = cv2.threshold(motion_frame, 10, 255, cv2.THRESH_BINARY)[1]
                # Nearest neighbour keeps the mask strictly binary.
                motion_frame = cv2.resize(
                    motion_frame,
                    (target_size[0] // stride[-2], target_size[1] // stride[-1]),
                    interpolation=cv2.INTER_NEAREST)
                motion_frames.append(motion_frame)
                prev_frame_gs = frame_gs
    finally:
        # Release the capture even if decoding/resizing raised.
        cap.release()

    if not frames:
        raise ValueError(f"Could not read any frames from video: {path}")

    # Shape: [F, H, W, C] -> [C, F, H, W]
    vid = np.array(frames).transpose(3, 0, 1, 2)
    vid = (vid / 127.5) - 1.0  # Normalize to [-1, 1]
    if frame_limit is not None:
        vid = vid[:, :frame_limit]
    video = torch.from_numpy(vid).float()
    if not ret_motion_video:
        return video

    # The last mask is repeated once before sub-sampling, as the original
    # implementation did (presumably as padding -- TODO confirm intent).
    motion_frames.append(motion_frames[-1])
    motion_mask = np.array(motion_frames)[::STRIDE][None]
    if frame_limit is not None:
        motion_mask = motion_mask[:, :frame_limit]
    return video, torch.from_numpy(motion_mask)
def get_video_fps(path):
    """Return the frame rate that OpenCV reports for the video at ``path``."""
    capture = cv2.VideoCapture(path)
    frame_rate = capture.get(cv2.CAP_PROP_FPS)
    capture.release()
    return frame_rate
def load_motion_video(path, target_size=(1280, 704)):
    """Load a video as a sequence of thresholded frame-to-frame differences.

    For every pair of consecutive frames the absolute grayscale difference is
    computed, resized to ``target_size`` and values below 25 are zeroed.
    Every 4th difference image is kept. Values are left in the 0-255 range
    (no normalization).

    Args:
        path: Video file path handed to ``cv2.VideoCapture``.
        target_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        A float32 tensor with a leading singleton axis followed by the kept
        difference images, i.e. [1, F', H, W]. A video with a single frame
        yields an empty tensor.

    Raises:
        ValueError: If not even the first frame can be read from ``path``.
    """
    cap = cv2.VideoCapture(path)
    motion_frames = []
    try:
        ret, first_frame = cap.read()
        if not ret:
            # Previously a None frame was passed straight into cvtColor.
            raise ValueError(f"Could not read any frames from video: {path}")
        prev_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            diff = cv2.absdiff(prev_frame, frame)
            diff = cv2.resize(diff, target_size)
            # Suppress small differences (noise) but keep magnitudes above it.
            diff[diff < 25] = 0
            motion_frames.append(diff)
            prev_frame = frame
    finally:
        # Release the capture even if decoding raised.
        cap.release()
    # [F, H, W] -> [1, F', H, W], keeping every 4th difference image.
    vid = np.array(motion_frames)[None, ::4, ...]
    return torch.from_numpy(vid).float()
def save_video(frames, path, fps, save_first_frame=False, gs=False):
    """Write frames to ``path`` as a libx264 video using imageio.

    Accepted inputs:
      * a torch tensor shaped [1, 3, F, H, W], [3, F, H, W], [1, F, H, W]
        (grayscale, replicated to 3 channels) or [F, H, W, 3];
      * a list/tuple of tensors, concatenated along dim 0 first;
      * anything else is passed to ``np.asarray`` and must be [F, H, W, 3].

    A trailing dimension of size 3 is always read as channels-last, so
    single-frame and three-frame [F, H, W, 3] clips are not mistaken for
    channels-first data.

    Tensor values are mapped to uint8 by a range heuristic: if the minimum is
    negative and the maximum is below 1.1 the data is treated as [-1, 1],
    otherwise as [0, 1]. Non-tensor input is only clipped to [0, 255] and
    cast, never rescaled.

    Args:
        frames: Frame data, see above. The caller's tensor is not modified.
        path: Output video path.
        fps: Frame rate forwarded to ``imageio.mimwrite``.
        save_first_frame: Also write the first frame next to the video as
            ``<path without extension>_ff.png``.
        gs: Convert the frames to 3-channel grayscale before writing.

    Raises:
        ValueError: If the frame data has an unsupported shape.
    """
    if isinstance(frames, (list, tuple)):
        frames = torch.cat(frames, dim=0)

    if isinstance(frames, torch.Tensor):
        frames = frames.detach().float().cpu()
        if frames.ndim == 5 and frames.shape[0] == 1:
            frames = frames[0]
        # Only re-layout when the tensor is not already channels-last.
        if frames.ndim == 4 and frames.shape[-1] != 3:
            if frames.shape[0] == 3:
                frames = frames.permute(1, 2, 3, 0)  # [3, F, H, W] -> [F, H, W, 3]
            elif frames.shape[0] == 1:
                # Grayscale: repeat the single channel to get RGB.
                frames = frames.repeat(3, 1, 1, 1).permute(1, 2, 3, 0)  # [1, F, H, W] -> [F, H, W, 3]
        if frames.ndim != 4 or frames.shape[-1] != 3:
            raise ValueError(f"Unsupported frame tensor shape: {tuple(frames.shape)}")

        # VAE output is typically in [-1, 1]. Convert to [0, 255] uint8.
        if frames.min() < 0 and frames.max() < 1.1:
            frames = (frames + 1.0) * 127.5
        else:
            # Out-of-place on purpose: `frames` may share storage with the
            # caller's tensor when it was already float32 on the CPU.
            frames = frames * 255.0
        frames = frames.clamp(0, 255).to(torch.uint8).numpy()
    else:
        frames = np.asarray(frames)
        if frames.ndim != 4 or frames.shape[-1] != 3:
            raise ValueError(f"Expected numpy frames with shape [F, H, W, 3], got {frames.shape}")
        if frames.dtype != np.uint8:
            frames = np.clip(frames, 0, 255).astype(np.uint8)

    if gs:
        transform = Grayscale(num_output_channels=3)
        # Grayscale works on [..., C, H, W]; convert there and back.
        channels_first = torch.from_numpy(frames).float().permute(0, 3, 1, 2)
        frames = transform(channels_first).permute(0, 2, 3, 1).byte().numpy()

    if save_first_frame:
        # splitext instead of path[:-4] so extensions of any length work.
        imageio.imwrite(f"{os.path.splitext(path)[0]}_ff.png", frames[0])
    imageio.mimwrite(path, frames, fps=fps, codec='libx264', quality=8)
def decode_video(latents, save_path="reconstructed_video.mp4", fps=24.0, vae_dec=None):
    """Decode latents with a VAE and optionally save the result as a video.

    Args:
        latents: Passed unchanged to ``vae_dec.decode``.
        save_path: Where to write the decoded video via ``save_video``; a
            falsy value (``None`` or ``""``) skips saving.
        fps: Frame rate used when saving.
        vae_dec: Object exposing ``decode(latents)``. When omitted, the
            module-level name ``vae`` is used; this file does not define it,
            so it must be provided by the surrounding code.

    Returns:
        Whatever ``vae_dec.decode(latents)`` returns.
    """
    vae_dec = vae_dec if vae_dec is not None else vae
    with torch.no_grad():
        video_recon = vae_dec.decode(latents)
    if save_path:
        save_video(video_recon, save_path, fps=fps)
    return video_recon
def encode_video(video_path, target_size=(1280, 704), frame_limit=None, vae_enc=None, ret_motion_mask=False, stride=(4, 16, 16)):
    """Load a video file and encode it with a VAE.

    Args:
        video_path: Video file to load through ``load_video``.
        target_size: Forwarded to ``load_video``.
        frame_limit: Forwarded to ``load_video``.
        vae_enc: Object exposing ``encode``, ``device`` and ``dtype``. Falls
            back to the module-level name ``vae`` when omitted.
        ret_motion_mask: Also load the motion mask and return it as a
            boolean tensor on the encoder's device.
        stride: Forwarded to ``load_video``.

    Returns:
        ``(latents, motion_mask)`` where ``latents`` is whatever
        ``vae_enc.encode`` returns for a one-element list holding the video
        tensor, and ``motion_mask`` is ``None`` unless ``ret_motion_mask``.
    """
    encoder = vae if vae_enc is None else vae_enc
    loaded = load_video(
        video_path,
        target_size=target_size,
        ret_motion_video=ret_motion_mask,
        frame_limit=frame_limit,
        stride=stride,
    )

    mask = None
    if ret_motion_mask:
        clip, mask = loaded
        mask = mask.to(encoder.device, dtype=encoder.dtype).bool()
    else:
        clip = loaded

    # Match the encoder's device and precision before encoding.
    clip = clip.to(encoder.device, dtype=encoder.dtype)
    with torch.no_grad():
        latents = encoder.encode([clip])
    return latents, mask
def noise_up(latents,
             x,
             noise=None,
             num_train_timesteps=1000,
             x_is_timestep=True,
             clamp_sigma=True,
             return_noise=False,
             generator=None):
    """
    Blend a clean latent with Gaussian noise (WAN flow-matching forward step).

    The noisy latent is computed in float32 as
        Z_x = (1 - sigma_x) * z0 + sigma_x * eps
    and cast back to the dtype of the clean latent.

    Args:
        latents: Clean latent tensor, or a list/tuple whose first element is
            the clean latent (only that element is used).
        x: Noise level. With x_is_timestep=True it is divided by
            num_train_timesteps to obtain sigma; otherwise it is sigma itself.
        noise: Optional epsilon tensor; sampled with torch.randn when None.
        num_train_timesteps: Divisor used to turn a timestep into sigma.
        x_is_timestep: Whether x is a timestep rather than sigma.
        clamp_sigma: Clamp sigma to [0, 1].
        return_noise: If True, also return the float32 epsilon that was used.
        generator: Optional torch.Generator used when sampling noise.

    Returns:
        The noisy latent as a tensor, or as a single-item list when a
        list/tuple was passed in; followed by epsilon if return_noise=True.
    """
    wrapped = isinstance(latents, (list, tuple))
    clean = latents[0] if wrapped else latents

    level = torch.as_tensor(x, device=clean.device, dtype=torch.float32)
    if x_is_timestep:
        level = level / float(num_train_timesteps)
    if clamp_sigma:
        level = level.clamp(0.0, 1.0)

    # Append singleton axes so sigma broadcasts against the latent.
    missing_dims = clean.ndim - level.ndim
    if missing_dims > 0:
        level = level.reshape(tuple(level.shape) + (1,) * missing_dims)

    if noise is None:
        eps = torch.randn(
            clean.shape,
            device=clean.device,
            dtype=torch.float32,
            generator=generator)
    else:
        eps = noise.to(device=clean.device, dtype=torch.float32)

    noisy = ((1.0 - level) * clean.to(torch.float32) + level * eps).to(clean.dtype)
    result = [noisy] if wrapped else noisy
    return (result, eps) if return_noise else result
def preprocess_guidance(video_path, model_type='5B', fps=None, gs=False):
    """Resample a guidance video to the frame count and size of a model preset.

    The video is loaded at the preset resolution, resampled to exactly the
    preset number of frames by picking evenly spaced (rounded) frame indices,
    and written next to the input as
    ``preprocessed_<model_type>_<num_frames>f[_gs]_<name>.mp4`` together with
    its first frame as a PNG.

    Args:
        video_path: Input video file.
        model_type: One of ``'5B'``, ``'14B-low'``, ``'14B-high'``.
        fps: Output frame rate. ``None`` (default) uses the preset's frame
            rate (24.0 for 5B, 16.0 for the 14B presets).
        gs: Save the output in grayscale; also adds ``_gs`` to the file name.

    Returns:
        Path of the written video.

    Raises:
        ValueError: If ``model_type`` is not a known preset.
    """
    # preset -> (num_frames, (width, height), default fps)
    presets = {
        '5B': (121, (1280, 704), 24.0),
        # 14B-low uses 832x464; (832, 480) was found to be wrong.
        '14B-low': (81, (832, 464), 16.0),
        '14B-high': (81, (1280, 720), 16.0),
    }
    if model_type not in presets:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {sorted(presets)}")
    num_frames, target_size, default_fps = presets[model_type]
    # An explicit fps is honoured; it used to be silently overwritten.
    new_fps = default_fps if fps is None else fps

    fn = os.path.splitext(os.path.basename(video_path))[0]
    video_tensor = load_video(video_path, target_size=target_size)
    orig_num_frames = video_tensor.shape[1]
    # Evenly spaced indices over the clip; repeats frames if the clip is short.
    new_frame_indices = np.round(np.linspace(0, orig_num_frames - 1, num_frames)).astype(int)
    adjusted_video_tensor = video_tensor[:, new_frame_indices]
    assert adjusted_video_tensor.shape[1] == num_frames, f"Expected {num_frames} frames after preprocessing, but got {adjusted_video_tensor.shape[1]}"
    print(f'removing {orig_num_frames - adjusted_video_tensor.shape[1]} frames')
    # os.path.join keeps the output relative when the input has no directory
    # part (a plain "/" prefix would point at the filesystem root).
    output_path = os.path.join(
        os.path.dirname(video_path),
        f"preprocessed_{model_type}_{num_frames}f{'_gs' if gs else ''}_{fn}.mp4")
    save_video(adjusted_video_tensor, output_path, fps=new_fps, save_first_frame=True, gs=gs)
    print(f'Saved in {output_path}')
    return output_path
def preprocess_guidance_ttm(video_path, model_type='5B'):
    """Sub-sample a guidance video by a constant integer step for a model preset.

    The video is loaded at the preset resolution, every ``skip_frame``-th
    frame is kept (``skip_frame`` = source frames // preset frames, at least
    1), the result is truncated to the preset frame count and written next to
    the input as ``preprocessed_<model_type>_<name>_<num_frames>f.mp4``
    together with its first frame as a PNG. Clips shorter than the preset
    frame count are saved with the frames they have.

    Args:
        video_path: Input video file.
        model_type: One of ``'5B'``, ``'14B-low'``, ``'14B-high'``.

    Returns:
        Path of the written video.

    Raises:
        ValueError: If ``model_type`` is not a known preset.
    """
    # preset -> (num_frames, (width, height), fps)
    presets = {
        '5B': (121, (1280, 704), 24.0),
        # 14B-low uses 832x464; (832, 480) was found to be wrong.
        '14B-low': (81, (832, 464), 16.0),
        '14B-high': (81, (1280, 720), 16.0),
    }
    if model_type not in presets:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {sorted(presets)}")
    num_frames, target_size, fps = presets[model_type]

    fn = os.path.splitext(os.path.basename(video_path))[0]
    video_tensor = load_video(video_path, target_size=target_size)
    adjusted_num_frames = video_tensor.shape[1]
    # At least 1: a zero slice step (clip shorter than num_frames) is an error.
    skip_frame = max(1, adjusted_num_frames // num_frames)
    print(f'sub-sampleing rate: {skip_frame}')
    video_tensor = video_tensor[:, ::skip_frame]
    print(f'removing {adjusted_num_frames - video_tensor.shape[1]}')
    video_tensor = video_tensor[:, :num_frames]
    # os.path.join keeps the output relative when the input has no directory part.
    output_path = os.path.join(
        os.path.dirname(video_path),
        f"preprocessed_{model_type}_{fn}_{num_frames}f.mp4")
    save_video(video_tensor, output_path, fps=fps, save_first_frame=True, gs=False)
    print(['Done.'])
    return output_path
def image_motion_mix(vid_path, img_path, model_type):
    """Overlay the frame-to-frame motion of a video onto a still image.

    The video is read as grayscale at the preset resolution. The output video
    starts with the (resized, RGB) reference image; each following frame is
    the reference image plus the signed difference ``frame[i-1] - frame[i]``
    added to all three channels, clipped to [0, 255]. The result is written to
    ``<vid_path without extension>_x_<image name>_diff.mp4``.

    Args:
        vid_path: Source video; its path must contain ``model_type``.
        img_path: Reference image.
        model_type: One of ``'5B'``, ``'14B-low'``, ``'14B-high'``.

    Raises:
        ValueError: If ``model_type`` is unknown, the video yields no frames,
            or the image cannot be read.
    """
    assert model_type in vid_path, f"{model_type!r} not found in video path {vid_path!r}"
    # preset -> ((width, height), fps)
    presets = {
        '5B': ((1280, 704), 24.0),
        # 14B-low uses 832x464; (832, 480) was found to be wrong.
        '14B-low': ((832, 464), 16.0),
        '14B-high': ((1280, 720), 16.0),
    }
    if model_type not in presets:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {sorted(presets)}")
    target_size, fps = presets[model_type]

    cap = cv2.VideoCapture(vid_path)
    frames = []
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            frames.append(cv2.resize(frame, target_size))
    finally:
        # The capture was previously never released.
        cap.release()
    if not frames:
        raise ValueError(f"Could not read any frames from video: {vid_path}")
    # Signed 16-bit so differences can be negative; uint8 arithmetic would
    # wrap around modulo 256 instead.
    frames = np.stack(frames, axis=0).astype(np.int16)

    ref_img = cv2.imread(img_path)
    if ref_img is None:
        raise ValueError(f"Could not read image: {img_path}")
    ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB)
    ref_img = cv2.resize(ref_img, target_size)
    ref_signed = ref_img.astype(np.int16)

    frames_diff = [ref_img]
    for prev_frame, cur_frame in zip(frames[:-1], frames[1:]):
        diff = prev_frame - cur_frame
        # Broadcast the single-channel difference over the RGB channels and
        # saturate instead of wrapping.
        fused_frame = np.clip(ref_signed + diff[..., None], 0, 255).astype(np.uint8)
        frames_diff.append(fused_frame)

    out_path = (
        f'{os.path.splitext(vid_path)[0]}_x_'
        f'{os.path.splitext(os.path.basename(img_path))[0]}_diff.mp4')
    imageio.mimwrite(out_path, frames_diff, fps=fps, codec='libx264', quality=8)
    print('save', out_path)
def decode_latent_img(latent, save_path='./latent.png', vae_dec=None):
    """Decode a single latent, save its first frame as an image and the latent as .npy.

    Args:
        latent: Latent tensor; it is wrapped in a one-element list for
            ``vae_dec.decode``.
        save_path: Image path. The latent is saved to the same path with
            ``.png`` replaced by ``.npy``.
        vae_dec: Object exposing ``decode``. When omitted, the module-level
            name ``vae`` is used; this file does not define it, so it must be
            provided by the surrounding code.

    Returns:
        The decoded tensor rescaled from [-1, 1] to [0, 255] (not clamped).
        The code indexes it as a 4-D tensor with channels first.
    """
    vae_dec = vae_dec if vae_dec is not None else vae
    with torch.no_grad():
        img_recon = vae_dec.decode([latent])[0]
    latent_cpu = latent.detach().cpu()
    if latent_cpu.dtype == torch.bfloat16:
        # NumPy has no bfloat16; .numpy() would raise on such tensors.
        latent_cpu = latent_cpu.float()
    np.save(save_path.replace('.png', '.npy'), latent_cpu.numpy())
    # Convert to scale [0, 255] and save.
    img_recon = (img_recon + 1) / 2 * 255
    first_frame = img_recon.permute(1, 2, 3, 0)[0]
    # Clamp before the uint8 cast so slightly out-of-range decoder output
    # saturates instead of wrapping around.
    imageio.imwrite(save_path, first_frame.float().clamp(0, 255).byte().cpu().numpy())
    return img_recon
if __name__ == "__main__":
    # Source clip that is converted into guidance videos for each preset.
    VID_TO_PROCESS = 'guidance_exmaples/woman_turning.mp4'
    # for I2V motion transfer
    # IMG_SOURCE = 'guidance_exmaples/i2v-mt/rock.png'
    output_path_5b, output_path_14bl = [
        preprocess_guidance(VID_TO_PROCESS, model_type=preset)
        for preset in ('5B', '14B-low')
    ]
    # for motion mix (for more stable I2V MT)
    # image_motion_mix(output_path_14bl, IMG_SOURCE, model_type='14B-low')