|
|
import os |
|
|
import sys |
|
|
sys.path.append(os.getcwd()) |
|
|
|
|
|
import argparse |
|
|
import warnings |
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
import time |
|
|
from contextlib import nullcontext |
|
|
from omegaconf import OmegaConf |
|
|
from pathlib import Path |
|
|
from tqdm import tqdm |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import decord |
|
|
from einops import rearrange |
|
|
from lightning.pytorch import seed_everything |
|
|
from torch import autocast |
|
|
from torchvision import transforms |
|
|
from torchvision.io import write_video |
|
|
|
|
|
from vidtok.modules.util import print0 |
|
|
from scripts.inference_evaluate import load_model_from_config |
|
|
|
|
|
|
|
|
class SingleVideoDataset(torch.utils.data.Dataset):
    """Dataset that yields fixed-size chunks of frames from a single video.

    Each item is a float tensor of shape (c, t, h, w) normalized to [-1, 1].
    In causal mode a chunk holds ``chunk_size + 1`` frames (the extra leading
    frame conditions the causal tokenizer); otherwise ``chunk_size`` frames.
    With ``read_long_video`` the whole (trimmed) video is returned as a single
    item whose length satisfies the chunking constraint; if the video is too
    short for even one chunk, the dataset is empty.
    """

    def __init__(
        self,
        video_path,
        input_height=128,
        input_width=128,
        sample_fps=8,
        chunk_size=16,
        is_causal=True,
        read_long_video=False
    ):
        # Have decord return torch tensors directly.
        decord.bridge.set_bridge("torch")
        self.video_path = video_path
        self.transform = transforms.Compose(
            [
                transforms.Resize(input_height, antialias=True),
                transforms.CenterCrop((input_height, input_width)),
                # Map [0, 1] pixels to [-1, 1].
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            ]
        )

        self.video_reader = decord.VideoReader(video_path, num_threads=0)
        total_frames = len(self.video_reader)
        fps = self.video_reader.get_avg_fps()

        # Subsample frames to approximately `sample_fps`. Guard against an
        # interval of 0 (which would make `range` raise ValueError) when the
        # source fps is lower than roughly half of `sample_fps`.
        interval = max(1, round(fps / sample_fps))
        frame_ids = list(range(0, total_frames, interval))
        self.frame_ids_batch = []
        if read_long_video:
            video_length = len(frame_ids)
            if is_causal and video_length > chunk_size:
                # Causal: keep 1 + k * chunk_size frames (leading conditioning frame).
                self.frame_ids_batch.append(frame_ids[:chunk_size * ((video_length - 1) // chunk_size) + 1])
            elif not is_causal and video_length >= chunk_size:
                # Non-causal: keep an exact multiple of chunk_size frames.
                self.frame_ids_batch.append(frame_ids[:chunk_size * (video_length // chunk_size)])
        else:
            num_frames_per_batch = chunk_size + 1 if is_causal else chunk_size
            for x in range(0, len(frame_ids), num_frames_per_batch):
                # Drop a trailing partial chunk.
                if len(frame_ids[x : x + num_frames_per_batch]) == num_frames_per_batch:
                    self.frame_ids_batch.append(frame_ids[x : x + num_frames_per_batch])

    def __len__(self):
        return len(self.frame_ids_batch)

    def __getitem__(self, idx):
        frame_ids = self.frame_ids_batch[idx]
        # (t, h, w, c) uint8 -> (t, c, h, w) float in [0, 1].
        frames = self.video_reader.get_batch(frame_ids).permute(0, 3, 1, 2).float() / 255.0
        # Apply resize/crop/normalize, then reorder to (c, t, h, w).
        frames = self.transform(frames).permute(1, 0, 2, 3)
        return frames
|
|
|
|
|
|
|
|
def tensor_to_uint8(tensor):
    """Map a float tensor in [-1, 1] to a uint8 numpy array in [0, 255]."""
    # Clamp to the valid range, then shift/scale from [-1, 1] into [0, 1].
    unit = (tensor.clamp(-1.0, 1.0) + 1.0) / 2.0
    # Scale to [0, 255]; astype(np.uint8) truncates the fractional part.
    return (unit.cpu().numpy() * 255).astype(np.uint8)
|
|
|
|
|
|
|
|
def main():
    """Reconstruct a single video with a pretrained VidTok tokenizer.

    Parses CLI arguments, loads the model, runs chunk-wise (or tiled
    long-video) reconstruction, and writes an mp4 with the reconstruction
    (optionally side by side with the input).
    """

    def str2bool(v):
        # Accept real booleans and common textual spellings of true/false.
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="full"
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/vidtok_kl_causal_488_4chn.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="checkpoints/vidtok_kl_causal_488_4chn.ckpt",
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--output_video_dir",
        type=str,
        default="tmp",
        help="path to save the outputs",
    )
    parser.add_argument(
        "--input_video_path",
        type=str,
        default="assets/example.mp4",
        help="path to the input video",
    )
    parser.add_argument(
        "--input_height",
        type=int,
        default=256,
        help="height of the input video",
    )
    parser.add_argument(
        "--input_width",
        type=int,
        default=256,
        help="width of the input video",
    )
    parser.add_argument(
        "--sample_fps",
        type=int,
        default=30,
        help="sample fps",
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=16,
        help="the size of a chunk - we split a long video into several chunks",
    )
    parser.add_argument(
        "--read_long_video",
        action='store_true'
    )
    parser.add_argument(
        "--pad_gen_frames",
        action="store_true",
        help="Used only in causal mode. If True, pad frames generated in the last batch, else replicate the first frame instead",
    )
    parser.add_argument(
        "--concate_input",
        type=str2bool,
        const=True,
        default=True,
        nargs="?",
        help="",
    )

    args = parser.parse_args()
    seed_everything(args.seed)

    print0(f"[bold red]\[scripts.inference_reconstruct][/bold red] Evaluating model {args.ckpt}")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # Autocast only when requested; nullcontext is a no-op stand-in otherwise.
    precision_scope = autocast if args.precision == "autocast" else nullcontext

    os.makedirs(args.output_video_dir, exist_ok=True)

    model = load_model_from_config(args.config, args.ckpt)
    model.to(device).eval()
    # Chunks must align with the encoder's temporal downsampling factor.
    assert args.chunk_size % model.encoder.time_downsample_factor == 0

    if args.read_long_video:
        assert hasattr(model, 'use_tiling'), "Tiling inference is needed to conduct long video reconstruction."
        # Use print0 for consistency with the rest of the script's logging.
        print0("Using tiling inference to save memory usage...")
        model.use_tiling = True
        model.t_chunk_enc = args.chunk_size
        model.t_chunk_dec = model.t_chunk_enc // model.encoder.time_downsample_factor
        model.use_overlap = True

    dataset = SingleVideoDataset(
        video_path=args.input_video_path,
        input_height=args.input_height,
        input_width=args.input_width,
        sample_fps=args.sample_fps,
        chunk_size=args.chunk_size,
        is_causal=model.is_causal,
        read_long_video=args.read_long_video
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

    inputs = []
    outputs = []
    with torch.no_grad(), precision_scope("cuda"):
        tic = time.time()
        # `batch` (renamed from `input`, which shadowed the builtin) has
        # shape (b, c, t, h, w).
        for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
            batch = batch.to(device)

            if model.is_causal and not args.read_long_video and args.pad_gen_frames:
                if i == 0:
                    _, xrec, _ = model(batch)
                else:
                    # Condition on the frames generated for the previous
                    # batch, then keep only the current batch's frames.
                    _, xrec, _ = model(torch.cat([last_gen_frames, batch], dim=2))
                    xrec = xrec[:, :, -batch.shape[2]:].clamp(-1, 1)
                # Carry the last (time_downsample_factor - 1) generated
                # frames over as conditioning for the next batch.
                last_gen_frames = xrec[:, :, (1 - model.encoder.time_downsample_factor):, :, :]
            else:
                _, xrec, _ = model(batch)

            batch = rearrange(batch, "b c t h w -> (b t) c h w")
            inputs.append(batch)
            xrec = rearrange(xrec.clamp(-1, 1), "b c t h w -> (b t) c h w")
            outputs.append(xrec)

        toc = time.time()

    # Convert both streams to (t, h, w, c) uint8 for video writing.
    inputs = tensor_to_uint8(torch.cat(inputs, dim=0))
    inputs = rearrange(inputs, "t c h w -> t h w c")
    outputs = tensor_to_uint8(torch.cat(outputs, dim=0))
    outputs = rearrange(outputs, "t c h w -> t h w c")
    min_len = min(inputs.shape[0], outputs.shape[0])
    # Optionally place input and reconstruction side by side (width axis).
    final = np.concatenate([inputs[:min_len], outputs[:min_len]], axis=2) if args.concate_input else outputs[:min_len]

    output_video_path = os.path.join(args.output_video_dir, f"{Path(args.input_video_path).stem}_reconstructed.mp4")
    write_video(output_video_path, final, args.sample_fps)

    print0(f"[bold red]Results saved in: {output_video_path}[/bold red]")
    print0(f"[bold red]\[scripts.inference_reconstruct][/bold red] Time taken: {toc - tic:.2f}s")
|
|
|
|
|
|
|
|
# Script entry point: run reconstruction when executed directly.
if __name__ == "__main__":
    main()
|
|
|