| import sys |
| import os |
| import torch |
| import cv2 |
| import numpy as np |
| import imageio |
| from torchvision.transforms import Grayscale, functional as TF |
|
|
# Number of frames produced for the 5B and 14B model families
# (used as the target clip lengths by the preprocessing helpers).
NUM_FRAMES_5B, NUM_FRAMES_14B = 121, 81

# Temporal sub-sampling step applied to the per-frame motion masks in load_video.
STRIDE = 8
|
|
def load_video(path, target_size=(1280, 704),
               ret_motion_video=False,
               frame_limit=None,
               stride=(4, 16, 16)):
    """Read a video file into a normalized tensor, optionally with a motion mask.

    Every frame is converted BGR -> RGB, resized to ``target_size`` and mapped
    from [0, 255] to [-1, 1].

    Args:
        path: Video file read with OpenCV.
        target_size: ``(width, height)`` passed to ``cv2.resize``.
        ret_motion_video: If True, also build a binary motion mask from the
            absolute difference of consecutive grayscale frames.
        frame_limit: If given, keep only the first ``frame_limit`` entries along
            the time axis of the video (and of the sub-sampled motion mask).
        stride: Down-sampling factors; ``stride[-2]`` divides the height and
            ``stride[-1]`` the width of the motion mask. Assumed ordered
            (T, H, W) — TODO confirm against the VAE config.

    Returns:
        A float32 tensor shaped ``(3, F, H, W)``. When ``ret_motion_video`` is
        True, a tuple ``(video, motion_mask)`` where ``motion_mask`` is a uint8
        tensor shaped ``(1, T', h, w)`` with values in {0, 255}, keeping every
        ``STRIDE``-th mask.

    Raises:
        ValueError: If no frame could be read from ``path``.
    """
    # cv2.resize takes dsize as (width, height): width pairs with stride[-1],
    # height with stride[-2].
    mask_size = (target_size[0] // stride[-1], target_size[1] // stride[-2])

    frames = []
    motion_frames = []
    prev_frame_gs = None
    cap = cv2.VideoCapture(path)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(cv2.resize(frame_rgb, target_size))

            if ret_motion_video:
                # Differences are taken at source resolution; the first frame is
                # compared with itself and therefore yields an all-zero mask.
                frame_gs = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                if prev_frame_gs is None:
                    prev_frame_gs = frame_gs
                motion_frame = cv2.absdiff(frame_gs, prev_frame_gs)
                motion_frame = cv2.threshold(motion_frame, 10, 255, cv2.THRESH_BINARY)[1]
                motion_frame = cv2.resize(motion_frame, mask_size, interpolation=cv2.INTER_NEAREST)
                motion_frames.append(motion_frame)
                prev_frame_gs = frame_gs
    finally:
        cap.release()

    if not frames:
        raise ValueError(f"Could not read any frames from video: {path}")

    vid = np.array(frames).transpose(3, 0, 1, 2)
    vid = (vid / 127.5) - 1.0
    if frame_limit is not None:
        vid = vid[:, :frame_limit]
    video = torch.from_numpy(vid).float()
    if not ret_motion_video:
        return video

    # The last mask is repeated once before temporal sub-sampling; this changes
    # how many masks survive ``[::STRIDE]`` when the frame count is a multiple
    # of STRIDE.
    motion_frames.append(motion_frames[-1])
    motion_mask = np.array(motion_frames)[::STRIDE][None]
    # NOTE(review): frame_limit here indexes the sub-sampled mask, not source
    # frames — confirm this is intended.
    if frame_limit is not None:
        motion_mask = motion_mask[:, :frame_limit]
    return video, torch.from_numpy(motion_mask)
|
|
|
|
def get_video_fps(path):
    """Return the frame rate OpenCV reports (``CAP_PROP_FPS``) for ``path``."""
    capture = cv2.VideoCapture(path)
    try:
        return capture.get(cv2.CAP_PROP_FPS)
    finally:
        capture.release()
|
|
|
|
def load_motion_video(path, target_size=(1280, 704), temporal_stride=4):
    """Build a tensor of thresholded frame-to-frame differences for a video.

    Consecutive frames are converted to grayscale, their absolute difference is
    resized to ``target_size``, values below 25 are zeroed, and every
    ``temporal_stride``-th difference map is kept.

    Args:
        path: Video file read with OpenCV.
        target_size: ``(width, height)`` passed to ``cv2.resize``.
        temporal_stride: Keep every ``temporal_stride``-th difference map
            (default 4).

    Returns:
        A float32 tensor shaped ``(1, T, H, W)`` with values in [0, 255] (the
        time axis is empty when the video has a single frame).

    Raises:
        ValueError: If the first frame of ``path`` cannot be read.
    """
    cap = cv2.VideoCapture(path)
    motion_frames = []
    try:
        ok, first = cap.read()
        if not ok:
            raise ValueError(f"Could not read any frames from video: {path}")
        prev_frame = cv2.cvtColor(first, cv2.COLOR_BGR2GRAY)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            diff = cv2.absdiff(prev_frame, frame)
            diff = cv2.resize(diff, target_size)
            # Zero out small differences (presumably noise).
            diff[diff < 25] = 0
            motion_frames.append(diff)
            prev_frame = frame
    finally:
        cap.release()

    vid = np.array(motion_frames)[None, ::temporal_stride, ...]
    return torch.from_numpy(vid).float()
|
|
|
|
def save_video(frames, path, fps, save_first_frame=False, gs=False):
    """Write a clip to ``path`` with imageio (libx264, quality 8).

    Args:
        frames: One of
            * a tensor shaped ``(3, F, H, W)`` or ``(1, 3, F, H, W)`` (RGB),
            * a tensor shaped ``(1, F, H, W)`` (single channel, replicated to RGB),
            * a tensor already shaped ``(F, H, W, 3)``,
            * a list/tuple of tensors, concatenated along dim 0 first,
            * a numpy array shaped ``(F, H, W, 3)``; non-uint8 values are
              clipped to [0, 255] and cast.
            uint8 tensors are taken as already in [0, 255]. Other tensors are
            rescaled heuristically: minimum < 0 and maximum < 1.1 is treated as
            [-1, 1], anything else as [0, 1].
        path: Output video path. With ``save_first_frame`` the first frame is
            also written as ``<path without extension>_ff.png``.
        fps: Output frame rate.
        save_first_frame: Also save the first frame as a PNG.
        gs: Convert frames to 3-channel grayscale before writing.

    Raises:
        ValueError: If the frames cannot be interpreted as ``(F, H, W, 3)``.
    """
    if isinstance(frames, (list, tuple)):
        frames = torch.cat(frames, dim=0)

    if isinstance(frames, torch.Tensor):
        already_uint8 = frames.dtype == torch.uint8
        frames = frames.detach().float().cpu()

        if frames.ndim == 5 and frames.shape[0] == 1:
            frames = frames[0]
        # At most one layout conversion: an RGB clip permuted to (F, H, W, 3)
        # must not be re-interpreted as single-channel when F == 1.
        if frames.ndim == 4 and frames.shape[0] == 3:
            frames = frames.permute(1, 2, 3, 0)
        elif frames.ndim == 4 and frames.shape[0] == 1 and frames.shape[-1] != 3:
            frames = frames.repeat(3, 1, 1, 1).permute(1, 2, 3, 0)
        if frames.ndim != 4 or frames.shape[-1] != 3:
            raise ValueError(f"Unsupported frame tensor shape: {tuple(frames.shape)}")

        if not already_uint8:
            if frames.min() < 0 and frames.max() < 1.1:
                frames = (frames + 1.0) * 127.5
            else:
                # Out-of-place on purpose: ``.detach().float().cpu()`` can
                # alias the caller's tensor, so ``*=`` would modify their data.
                frames = frames * 255.0
        frames = frames.clamp(0, 255).to(torch.uint8).numpy()
    else:
        frames = np.asarray(frames)
        if frames.ndim != 4 or frames.shape[-1] != 3:
            raise ValueError(f"Expected numpy frames with shape [F, H, W, 3], got {frames.shape}")
        if frames.dtype != np.uint8:
            frames = np.clip(frames, 0, 255).astype(np.uint8)

    if gs:
        transform = Grayscale(num_output_channels=3)
        frames = transform(torch.from_numpy(frames).float().permute(0, 3, 1, 2)).permute(0, 2, 3, 1).byte().numpy()

    if save_first_frame:
        imageio.imwrite(f"{os.path.splitext(path)[0]}_ff.png", frames[0])
    imageio.mimwrite(path, frames, fps=fps, codec='libx264', quality=8)
|
|
|
|
def decode_video(latents, save_path="reconstructed_video.mp4", fps=24.0, vae_dec=None):
    """Decode latents with a VAE and optionally save the reconstruction.

    Args:
        latents: Passed unchanged to ``vae_dec.decode``. Elsewhere in this
            module the VAE is called with a list of latent tensors.
        save_path: Destination for ``save_video``; a falsy value skips saving.
        fps: Frame rate of the saved clip.
        vae_dec: Object exposing ``decode``. Defaults to the module-level
            ``vae``, which is not assigned in this file and is presumably
            injected by the caller.

    Returns:
        Whatever ``vae_dec.decode`` returns.
    """
    # Only fall back to the global when no decoder is given, so callers that
    # pass vae_dec do not need the module-level ``vae`` to exist.
    vae_dec = vae_dec if vae_dec is not None else vae
    with torch.no_grad():
        video_recon = vae_dec.decode(latents)
    if save_path:
        save_video(video_recon, save_path, fps=fps)
    return video_recon
|
|
|
|
def encode_video(video_path, target_size=(1280, 704), frame_limit=None, vae_enc=None, ret_motion_mask=False, stride=(4, 16, 16)):
    """Load a video with ``load_video`` and encode it into VAE latents.

    Args:
        video_path: File read by ``load_video``.
        target_size: ``(width, height)`` the frames are resized to.
        frame_limit: Forwarded to ``load_video``.
        vae_enc: Encoder exposing ``encode``, ``device`` and ``dtype``; the
            module-level ``vae`` is used when this is None.
        ret_motion_mask: Also return the motion mask from ``load_video``,
            moved to the encoder's device and converted to bool.
        stride: Forwarded to ``load_video``.

    Returns:
        ``(latents, motion_mask)``; ``motion_mask`` is None unless
        ``ret_motion_mask`` is True.
    """
    encoder = vae if vae_enc is None else vae_enc

    loaded = load_video(video_path, target_size=target_size, ret_motion_video=ret_motion_mask,
                        frame_limit=frame_limit, stride=stride)
    video_tensor, motion_mask = loaded if ret_motion_mask else (loaded, None)

    device, dtype = encoder.device, encoder.dtype
    if motion_mask is not None:
        motion_mask = motion_mask.to(device, dtype=dtype)
    clip = video_tensor.to(device, dtype=dtype)

    # encode() receives a list holding the single (3, F, H, W) clip.
    with torch.no_grad():
        latents = encoder.encode([clip])

    return latents, (None if motion_mask is None else motion_mask.bool())
|
|
|
|
def noise_up(latents,
             x,
             noise=None,
             num_train_timesteps=1000,
             x_is_timestep=True,
             clamp_sigma=True,
             return_noise=False,
             generator=None):
    """Blend a clean latent with Gaussian noise along the WAN flow-matching path.

    The noisy latent is ``Z = (1 - sigma) * z0 + sigma * eps``, computed in
    float32 and cast back to the latent's dtype.

    Args:
        latents: A tensor, or a list/tuple whose first element is the latent
            (any further elements are ignored).
        x: Noise level. When ``x_is_timestep`` is True it is a timestep and is
            divided by ``num_train_timesteps`` to obtain sigma; otherwise it is
            sigma itself.
        noise: Optional ``eps``; drawn with ``torch.randn`` in the latent's shape
            when omitted.
        num_train_timesteps: Divisor mapping timesteps to sigma (default 1000).
        x_is_timestep: How ``x`` is interpreted (see above).
        clamp_sigma: Clamp sigma into [0, 1].
        return_noise: Also return the float32 ``eps`` that was used.
        generator: Optional ``torch.Generator`` used when sampling ``eps``.

    Returns:
        The noisy latent, wrapped in a one-element list when ``latents`` was a
        list or tuple; with ``return_noise`` a tuple ``(noisy, eps)``.
    """
    wrapped = isinstance(latents, (list, tuple))
    z0 = latents[0] if wrapped else latents
    device = z0.device

    sigma = torch.as_tensor(x, device=device, dtype=torch.float32)
    if x_is_timestep:
        sigma = sigma / float(num_train_timesteps)
    if clamp_sigma:
        sigma = sigma.clamp(0.0, 1.0)
    # Append trailing singleton dims so sigma broadcasts against the latent.
    missing_dims = z0.ndim - sigma.ndim
    if missing_dims > 0:
        sigma = sigma.reshape(tuple(sigma.shape) + (1,) * missing_dims)

    if noise is None:
        eps = torch.randn(z0.shape, device=device, dtype=torch.float32, generator=generator)
    else:
        eps = noise.to(device=device, dtype=torch.float32)

    noisy = ((1.0 - sigma) * z0.to(torch.float32) + sigma * eps).to(z0.dtype)
    result = [noisy] if wrapped else noisy
    return (result, eps) if return_noise else result
|
|
|
|
def preprocess_guidance(video_path, model_type='5B', fps=None, gs=False):
    """Resample a guidance video to the frame count and resolution a model expects.

    The clip is loaded at the model's resolution. ``num_frames`` frames are then
    picked at evenly spaced (rounded) indices. The result is written next to the
    input as ``preprocessed_<model_type>_<num_frames>f[_gs]_<name>.mp4``, along
    with a first-frame PNG.

    Args:
        video_path: Input video file.
        model_type: ``'5B'`` (121 frames, 1280x704, 24 fps), ``'14B-low'``
            (81 frames, 832x464, 16 fps) or ``'14B-high'`` (81 frames,
            1280x720, 16 fps).
        fps: Output frame rate; defaults to the model's rate listed above.
        gs: Write the clip in grayscale (also adds ``_gs`` to the file name).

    Returns:
        Path of the written video.

    Raises:
        ValueError: If ``model_type`` is not recognised.
    """
    configs = {
        '5B': (NUM_FRAMES_5B, (1280, 704), 24.0),
        '14B-low': (NUM_FRAMES_14B, (832, 464), 16.0),
        '14B-high': (NUM_FRAMES_14B, (1280, 720), 16.0),
    }
    if model_type not in configs:
        raise ValueError(f"Unknown model_type {model_type!r}; expected one of {sorted(configs)}")
    num_frames, target_size, default_fps = configs[model_type]
    out_fps = default_fps if fps is None else fps

    fn = os.path.splitext(os.path.basename(video_path))[0]
    video_tensor = load_video(video_path, target_size=target_size)
    orig_num_frames = video_tensor.shape[1]
    # Evenly spaced indices; frames are repeated when the source is shorter.
    new_frame_indices = np.round(np.linspace(0, orig_num_frames - 1, num_frames)).astype(int)
    adjusted_video_tensor = video_tensor[:, new_frame_indices]

    print(f'removing {orig_num_frames - adjusted_video_tensor.shape[1]} frames')
    # os.path.join avoids writing to "/" when video_path has no directory part.
    output_path = os.path.join(
        os.path.dirname(video_path),
        f"preprocessed_{model_type}_{num_frames}f{'_gs' if gs else ''}_{fn}.mp4")
    save_video(adjusted_video_tensor, output_path, fps=out_fps, save_first_frame=True, gs=gs)
    print(f'Saved in {output_path}')
    return output_path
|
|
|
|
def preprocess_guidance_ttm(video_path, model_type='5B'):
    """Sub-sample a guidance video with a fixed integer frame step.

    The clip is loaded at the model's resolution. Every ``skip``-th frame is
    kept, with ``skip = total_frames // num_frames`` (at least 1). The result is
    truncated to ``num_frames`` and saved next to the input as
    ``preprocessed_<model_type>_<name>_<num_frames>f.mp4``, with a first-frame PNG.

    Args:
        video_path: Input video file.
        model_type: ``'5B'`` (121 frames, 1280x704, 24 fps), ``'14B-low'``
            (81 frames, 832x464, 16 fps) or ``'14B-high'`` (81 frames,
            1280x720, 16 fps).

    Returns:
        Path of the written video.

    Raises:
        ValueError: If ``model_type`` is not recognised.
    """
    configs = {
        '5B': (NUM_FRAMES_5B, (1280, 704), 24.0),
        '14B-low': (NUM_FRAMES_14B, (832, 464), 16.0),
        '14B-high': (NUM_FRAMES_14B, (1280, 720), 16.0),
    }
    if model_type not in configs:
        raise ValueError(f"Unknown model_type {model_type!r}; expected one of {sorted(configs)}")
    num_frames, target_size, fps = configs[model_type]

    fn = os.path.splitext(os.path.basename(video_path))[0]
    video_tensor = load_video(video_path, target_size=target_size)
    adjusted_num_frames = video_tensor.shape[1]
    # A clip shorter than num_frames would give a zero step, and slicing with
    # ``::0`` raises; keep every frame instead.
    skip_frame = max(1, adjusted_num_frames // num_frames)
    print(f'sub-sampleing rate: {skip_frame}')
    video_tensor = video_tensor[:, ::skip_frame]
    print(f'removing {adjusted_num_frames - video_tensor.shape[1]}')
    video_tensor = video_tensor[:, :num_frames]
    output_path = os.path.join(os.path.dirname(video_path),
                               f"preprocessed_{model_type}_{fn}_{num_frames}f.mp4")
    save_video(video_tensor, output_path, fps=fps, save_first_frame=True, gs=False)
    print('Done.')
    return output_path
|
|
|
|
def image_motion_mix(vid_path, img_path, model_type):
    """Overlay a video's frame-to-frame grayscale differences on a reference image.

    Output frame 0 is the reference image. Output frame ``i`` is the reference
    image plus ``gray[i-1] - gray[i]`` (replicated to 3 channels), computed in
    signed arithmetic and clipped to [0, 255]. The clip is written to
    ``<vid_path without extension>_x_<image stem>_diff.mp4``.

    Args:
        vid_path: Guidance video; its path must contain ``model_type``.
        img_path: Reference image read with OpenCV.
        model_type: ``'5B'``, ``'14B-low'`` or ``'14B-high'``; selects the
            resolution and output frame rate.

    Raises:
        ValueError: If ``vid_path`` does not contain ``model_type``, the model
            type is unknown, the image cannot be read, or the video has no frames.
    """
    if model_type not in vid_path:
        raise ValueError(f"vid_path {vid_path!r} does not match model_type {model_type!r}")
    configs = {
        '5B': ((1280, 704), 24.0),
        '14B-low': ((832, 464), 16.0),
        '14B-high': ((1280, 720), 16.0),
    }
    if model_type not in configs:
        raise ValueError(f"Unknown model_type {model_type!r}; expected one of {sorted(configs)}")
    target_size, fps = configs[model_type]

    frames = []
    cap = cv2.VideoCapture(vid_path)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            frames.append(cv2.resize(frame, target_size))
    finally:
        cap.release()
    if not frames:
        raise ValueError(f"Could not read any frames from video: {vid_path}")

    ref_img = cv2.imread(img_path)
    if ref_img is None:
        raise ValueError(f"Could not read reference image: {img_path}")
    ref_img = cv2.resize(cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB), target_size)

    # Work in int16: uint8 subtraction/addition would silently wrap around.
    gray = np.stack(frames, axis=0).astype(np.int16)
    diffs = gray[:-1] - gray[1:]
    fused = np.clip(ref_img.astype(np.int16)[None] + diffs[..., None], 0, 255).astype(np.uint8)
    frames_diff = [ref_img] + list(fused)

    out_path = (f'{os.path.splitext(vid_path)[0]}_x_'
                f'{os.path.splitext(os.path.basename(img_path))[0]}_diff.mp4')
    imageio.mimwrite(out_path, frames_diff, fps=fps, codec='libx264', quality=8)
    print('save', out_path)
|
|
|
|
def decode_latent_img(latent, save_path='./latent.png', vae_dec=None):
    """Decode one latent, save its first frame as a PNG and the latent as ``.npy``.

    Args:
        latent: Latent tensor, passed to ``vae_dec.decode`` wrapped in a list.
        save_path: PNG path for the first decoded frame. The raw latent is
            saved alongside, with the extension replaced by ``.npy``.
        vae_dec: Object exposing ``decode``. Defaults to the module-level
            ``vae``, which is not assigned in this file and is presumably
            injected by the caller.

    Returns:
        The first decoded output mapped from [-1, 1] to [0, 255] (float,
        not clamped).
    """
    vae_dec = vae_dec if vae_dec is not None else vae
    with torch.no_grad():
        img_recon = vae_dec.decode([latent])[0]

    latent_cpu = latent.detach().cpu()
    # numpy has no bfloat16, so widen it before export.
    if latent_cpu.dtype == torch.bfloat16:
        latent_cpu = latent_cpu.float()
    np.save(f"{os.path.splitext(save_path)[0]}.npy", latent_cpu.numpy())

    img_recon = (img_recon + 1) / 2 * 255
    # Clamp before the uint8 cast: decoder output is not guaranteed to stay in
    # [-1, 1], and out-of-range float->uint8 casts do not saturate.
    first_frame = img_recon.permute(1, 2, 3, 0)[0].float().clamp(0, 255).byte().cpu()
    imageio.imwrite(save_path, first_frame.numpy())
    return img_recon
|
|
|
|
if __name__ == "__main__":
    # NOTE(review): 'exmaples' looks like a typo — confirm the directory name on disk.
    VID_TO_PROCESS = 'guidance_exmaples/woman_turning.mp4'

    # One preprocessed guidance clip per model family, produced in order.
    output_path_5b, output_path_14bl = (
        preprocess_guidance(VID_TO_PROCESS, model_type=kind)
        for kind in ('5B', '14B-low')
    )
|
|
| |
| |
|
|