#!/usr/bin/env python3
"""
Test script for VidTok tokenizer performance.
This script:
1. Loads a video from the training dataset
2. Encodes it using VidTok tokenizer
3. Decodes it back
4. Computes metrics (PSNR, SSIM, MSE)
5. Creates a side-by-side comparison video
6. Saves the results
Based on VidTok: https://github.com/microsoft/VidTok
"""
import argparse
import os
import sys
sys.path.append(os.getcwd())
import torch
import numpy as np
from PIL import Image
import cv2
from torchvision import transforms
from torchvision.utils import make_grid, save_image
# VidTok imports - adjust path if needed
VIDTOK_AVAILABLE = False
VIDTOK_PATH = None
def _setup_vidtok():
"""Setup VidTok by trying to import or download from GitHub."""
global VIDTOK_AVAILABLE, VIDTOK_PATH
# Try to import from existing installation
try:
from scripts.inference_evaluate import load_model_from_config
VIDTOK_AVAILABLE = True
return load_model_from_config
except ImportError:
pass
# Try to find VidTok in common locations
vidtok_paths = [
"VidTok",
"../VidTok",
os.path.join(os.path.dirname(__file__), "../VidTok"),
os.path.expanduser("~/VidTok"),
]
for vidtok_path in vidtok_paths:
if os.path.exists(vidtok_path) and os.path.exists(os.path.join(vidtok_path, "scripts")):
sys.path.insert(0, vidtok_path)
try:
from scripts.inference_evaluate import load_model_from_config
VIDTOK_AVAILABLE = True
VIDTOK_PATH = vidtok_path
print(f"Found VidTok at: {vidtok_path}")
return load_model_from_config
except ImportError:
if vidtok_path in sys.path:
sys.path.remove(vidtok_path)
continue
# Try to download from GitHub
print("VidTok not found locally. Attempting to download from GitHub...")
try:
import subprocess
# Create cache directory
cache_dir = os.path.join(os.getcwd(), "vidtok_cache")
vidtok_dir = os.path.join(cache_dir, "VidTok")
# Check if already downloaded
if os.path.exists(vidtok_dir) and os.path.exists(os.path.join(vidtok_dir, "scripts")):
sys.path.insert(0, vidtok_dir)
try:
from scripts.inference_evaluate import load_model_from_config
VIDTOK_AVAILABLE = True
VIDTOK_PATH = vidtok_dir
print(f"Using cached VidTok from: {vidtok_dir}")
return load_model_from_config
except ImportError:
if vidtok_dir in sys.path:
sys.path.remove(vidtok_dir)
# Download from GitHub
print("Downloading VidTok from GitHub...")
os.makedirs(cache_dir, exist_ok=True)
        # Use git clone if available, otherwise fall back to a zip download
        import shutil
        if shutil.which("git") is not None:
            # Clone the repository (remove any stale partial clone first)
            if os.path.exists(vidtok_dir):
                shutil.rmtree(vidtok_dir)
result = subprocess.run(
["git", "clone", "--depth", "1", "https://github.com/microsoft/VidTok.git", vidtok_dir],
capture_output=True,
text=True
)
if result.returncode == 0:
sys.path.insert(0, vidtok_dir)
try:
from scripts.inference_evaluate import load_model_from_config
VIDTOK_AVAILABLE = True
VIDTOK_PATH = vidtok_dir
print(f"Successfully downloaded VidTok to: {vidtok_dir}")
return load_model_from_config
except ImportError as e:
if vidtok_dir in sys.path:
sys.path.remove(vidtok_dir)
print(f"Failed to import VidTok after download: {e}")
else:
# Fallback: download zip file
print("Git not available, trying to download zip file...")
import urllib.request
import zipfile
zip_url = "https://github.com/microsoft/VidTok/archive/refs/heads/main.zip"
zip_path = os.path.join(cache_dir, "VidTok-main.zip")
print(f"Downloading {zip_url}...")
urllib.request.urlretrieve(zip_url, zip_path)
# Extract zip
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(cache_dir)
# Rename extracted directory
extracted_dir = os.path.join(cache_dir, "VidTok-main")
            if os.path.exists(vidtok_dir):
                shutil.rmtree(vidtok_dir)
os.rename(extracted_dir, vidtok_dir)
sys.path.insert(0, vidtok_dir)
try:
from scripts.inference_evaluate import load_model_from_config
VIDTOK_AVAILABLE = True
VIDTOK_PATH = vidtok_dir
print(f"Successfully downloaded VidTok to: {vidtok_dir}")
return load_model_from_config
except ImportError as e:
if vidtok_dir in sys.path:
sys.path.remove(vidtok_dir)
print(f"Failed to import VidTok after download: {e}")
except Exception as e:
print(f"Failed to download VidTok: {e}")
return None
# Setup VidTok
load_model_from_config = _setup_vidtok()
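# NOTE: environment-specific path to the Meissonic repo; adjust for your setup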
sys.path.append("/mnt/Meissonic")
from train.dataset_utils import OpenVid1MDataset
from transformers import T5Tokenizer
def calculate_psnr(img1, img2, max_val=1.0):
"""Calculate PSNR between two images."""
    # Compute on CPU: this avoids GPU memory pressure and guarantees both
    # tensors end up on the same device.
    img1 = img1.cpu()
    img2 = img2.cpu()
mse = torch.mean((img1 - img2) ** 2)
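    # PSNR = 20 * log10(MAX / sqrt(MSE)); e.g. an MSE of 0.001 on [0, 1] data
    # gives 20 * log10(1 / sqrt(0.001)) ≈ 30 dB.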
if mse == 0:
return float('inf')
psnr = 20 * torch.log10(max_val / torch.sqrt(mse))
return psnr.item()
def calculate_mse(img1, img2):
"""Calculate MSE between two images."""
    # Compute on CPU: this avoids GPU memory pressure and guarantees both
    # tensors end up on the same device.
    img1 = img1.cpu()
    img2 = img2.cpu()
return torch.mean((img1 - img2) ** 2).item()
def calculate_ssim(img1, img2):
    """Calculate SSIM between two images (simplified, global-statistics version)."""
    # Compute on CPU: this avoids GPU memory pressure and guarantees both
    # tensors end up on the same device.
    img1 = img1.cpu()
    img2 = img2.cpu()
# Simple SSIM approximation
C1 = 0.01 ** 2
C2 = 0.03 ** 2
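    # C1 = (K1 * L)^2 and C2 = (K2 * L)^2 with the standard K1 = 0.01, K2 = 0.03
    # and dynamic range L = 1 for [0, 1] tensors. Global statistics are used here
    # instead of the usual 11x11 Gaussian windows, so values will differ from a
    # windowed SSIM implementation.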
mu1 = img1.mean()
mu2 = img2.mean()
sigma1_sq = img1.var()
sigma2_sq = img2.var()
sigma12 = ((img1 - mu1) * (img2 - mu2)).mean()
ssim = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / ((mu1**2 + mu2**2 + C1) * (sigma1_sq + sigma2_sq + C2))
return ssim.item()
def video_to_numpy(video_tensor):
"""
Convert video tensor [C, F, H, W] in [-1, 1] or [0, 1] to numpy array [F, H, W, C] in [0, 255] (RGB).
VidTok uses [-1, 1] range.
"""
if isinstance(video_tensor, torch.Tensor):
# [C, F, H, W] -> [F, C, H, W] -> [F, H, W, C]
video_np = video_tensor.permute(1, 0, 2, 3).cpu().numpy() # [F, C, H, W]
video_np = np.transpose(video_np, (0, 2, 3, 1)) # [F, H, W, C]
# Normalize to [0, 1] range (VidTok uses [-1, 1])
if video_np.min() < 0:
# Assume range is [-1, 1]
video_np = (video_np + 1) / 2
else:
# Assume range is [0, 1]
video_np = np.clip(video_np, 0, 1)
# Convert to [0, 255]
video_np = (video_np * 255).astype(np.uint8)
else:
video_np = np.array(video_tensor)
return video_np
def add_text_to_image(image_tensor, text, position=(10, 30)):
"""
Add text label to an image tensor.
Args:
image_tensor: Image tensor [C, H, W] in [0, 1] or [-1, 1]
text: Text to add
position: (x, y) position for text
Returns:
Image tensor with text [C, H, W]
"""
# Normalize to [0, 1] if needed
img_norm = image_tensor.clone()
if img_norm.min() < 0:
img_norm = (img_norm + 1) / 2
img_norm = torch.clamp(img_norm, 0, 1)
# Convert to PIL Image
image_np = img_norm.permute(1, 2, 0).cpu().numpy() # [H, W, C]
image_np = (image_np * 255).astype(np.uint8)
pil_image = Image.fromarray(image_np)
# Add text
from PIL import ImageDraw, ImageFont
draw = ImageDraw.Draw(pil_image)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 24)
    except OSError:
        try:
            font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 24)
        except OSError:
            font = ImageFont.load_default()
# Draw white text with black outline
x, y = position
# Draw outline
for adj in [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]:
draw.text((x + adj[0], y + adj[1]), text, font=font, fill=(0, 0, 0))
# Draw main text
draw.text((x, y), text, font=font, fill=(255, 255, 255))
# Convert back to tensor (normalize to [0, 1])
image_tensor = transforms.ToTensor()(pil_image)
return image_tensor
def create_side_by_side_video(original, reconstructed, output_path, fps=8):
"""
Create a side-by-side comparison video.
Args:
original: Original video tensor [C, F, H, W] in [0, 1] or [-1, 1]
reconstructed: Reconstructed video tensor [C, F, H, W] in [0, 1] or [-1, 1]
output_path: Path to save the output video
fps: Frames per second
"""
# Convert to numpy (RGB format: [F, H, W, C])
orig_np = video_to_numpy(original)
recon_np = video_to_numpy(reconstructed)
# Get dimensions
F, H, W, C = orig_np.shape
F_recon, H_recon, W_recon, C_recon = recon_np.shape
# Ensure same number of frames
F_min = min(F, F_recon)
orig_np = orig_np[:F_min]
recon_np = recon_np[:F_min]
# Resize if dimensions don't match
if H != H_recon or W != W_recon:
print(f"Resizing reconstructed video from ({H_recon}, {W_recon}) to ({H}, {W})")
recon_np_resized = np.zeros((F_min, H, W, C), dtype=np.uint8)
for f in range(F_min):
recon_np_resized[f] = cv2.resize(recon_np[f], (W, H), interpolation=cv2.INTER_LINEAR)
recon_np = recon_np_resized
    # Add text labels to frames
    from PIL import ImageDraw, ImageFont  # Image is already imported at module level
    # Load the label font once (fall back to PIL's default if no TrueType font is found)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 32)
    except OSError:
        try:
            font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 32)
        except OSError:
            font = ImageFont.load_default()
    side_by_side_frames = []
    for f in range(F_min):
        # Original frame with label
        orig_frame_pil = Image.fromarray(orig_np[f])
        draw = ImageDraw.Draw(orig_frame_pil)
# Draw text with outline for visibility
text = "Original"
x, y = 20, 20
for adj in [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]:
draw.text((x + adj[0], y + adj[1]), text, font=font, fill=(0, 0, 0))
draw.text((x, y), text, font=font, fill=(255, 255, 255))
orig_frame = np.array(orig_frame_pil)
# Reconstructed frame with label
recon_frame_pil = Image.fromarray(recon_np[f])
draw = ImageDraw.Draw(recon_frame_pil)
text = "Reconstructed"
x, y = 20, 20
for adj in [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]:
draw.text((x + adj[0], y + adj[1]), text, font=font, fill=(0, 0, 0))
draw.text((x, y), text, font=font, fill=(255, 255, 0)) # Yellow text
recon_frame = np.array(recon_frame_pil)
# Concatenate horizontally
frame = np.concatenate([orig_frame, recon_frame], axis=1)
side_by_side_frames.append(frame)
# Write video using OpenCV (needs BGR format)
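    # Note: cv2.VideoWriter expects the frame size as (width, height), hence (W * 2, H)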
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (W * 2, H))
if not out.isOpened():
print(f"Warning: Could not open video writer with mp4v codec, trying XVID...")
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter(output_path, fourcc, fps, (W * 2, H))
for frame in side_by_side_frames:
# Convert RGB to BGR for OpenCV
frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
out.write(frame_bgr)
out.release()
print(f"Saved side-by-side video to: {output_path}")
def create_comparison_grid(original, reconstructed, output_path, nrow=4):
"""
Create a grid image comparing original and reconstructed frames.
Args:
original: Original video tensor [C, F, H, W] in [0, 1] or [-1, 1]
reconstructed: Reconstructed video tensor [C, F, H, W] in [0, 1] or [-1, 1]
output_path: Path to save the grid image
nrow: Number of frames per row
"""
# Get number of frames
F = min(original.shape[1], reconstructed.shape[1])
# Select frames to display
num_frames_to_show = min(8, F)
frame_indices = np.linspace(0, F - 1, num_frames_to_show, dtype=int)
frames_list = []
for idx in frame_indices:
# Original frame with label
orig_frame = original[:, idx, :, :].clone() # [C, H, W]
orig_frame = add_text_to_image(orig_frame, "Original", position=(10, 10))
frames_list.append(orig_frame)
# Reconstructed frame with label
recon_frame = reconstructed[:, idx, :, :].clone() # [C, H, W]
recon_frame = add_text_to_image(recon_frame, "Reconstructed", position=(10, 10))
frames_list.append(recon_frame)
# Create grid
frames_tensor = torch.stack(frames_list, dim=0)
grid = make_grid(frames_tensor, nrow=nrow * 2, padding=2, pad_value=1.0)
save_image(grid, output_path)
print(f"Saved comparison grid to: {output_path}")
def parse_args():
parser = argparse.ArgumentParser(description="Test VidTok tokenizer performance")
parser.add_argument(
"--config",
type=str,
default=None,
help="Path to VidTok config file (e.g., configs/vidtok_kl_causal_488_4chn.yaml). "
"If not provided, will try to download from HuggingFace."
)
parser.add_argument(
"--ckpt",
type=str,
default=None,
help="Path to VidTok checkpoint file or HuggingFace model ID. "
"If HuggingFace model ID (e.g., microsoft/VidTok), will download automatically."
)
parser.add_argument(
"--model_name",
type=str,
default="vidtok_kl_causal_488_4chn",
help="VidTok model name for HuggingFace download. "
"Options: vidtok_kl_causal_488_4chn, vidtok_kl_noncausal_488_4chn, "
"vidtok_fsq_causal_488_4096, etc. (default: vidtok_kl_causal_488_4chn)"
)
parser.add_argument(
"--csv_path",
type=str,
required=True,
help="Path to OpenVid1M CSV file"
)
parser.add_argument(
"--video_root_dir",
type=str,
default=None,
help="Root directory for videos (auto-detected if not provided)"
)
parser.add_argument(
"--video_index",
type=int,
default=0,
help="Index of video to test (default: 0)"
)
parser.add_argument(
"--num_frames",
type=int,
default=16,
help="Number of frames (use 17 for causal models, 16 for non-causal)"
)
parser.add_argument(
"--height",
type=int,
default=256,
help="Video height"
)
parser.add_argument(
"--width",
type=int,
default=256,
help="Video width"
)
parser.add_argument(
"--output_dir",
type=str,
default="./vidtok_test_output",
help="Output directory for results"
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to use"
)
parser.add_argument(
"--use_continuous",
action="store_true",
help="Use continuous latent space for decoding (default: use discrete tokens)"
)
return parser.parse_args()
def main():
args = parse_args()
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
# Set device
device = torch.device(args.device)
print(f"Using device: {device}")
# Determine config and checkpoint paths
config_path = args.config
ckpt_path = args.ckpt
# If checkpoint is a HuggingFace model ID, download from HuggingFace
    if ckpt_path is None or ckpt_path.startswith("microsoft/VidTok") or ("/" in ckpt_path and not os.path.exists(ckpt_path)):
print(f"Downloading VidTok model from HuggingFace...")
try:
            from huggingface_hub import hf_hub_download
# Determine model ID
if ckpt_path and ckpt_path.startswith("microsoft/VidTok"):
repo_id = ckpt_path
else:
repo_id = "microsoft/VidTok"
            # Resolve the checkpoint filename
            # NOTE: pinned to the v1.1 16-channel causal checkpoint; this overrides --model_name
            checkpoint_filename = "vidtok_v1_1/vidtok_kl_causal_488_16chn_v1_1.ckpt"
print(f"Downloading checkpoint: {checkpoint_filename}")
# Create temporary directory for downloads
cache_dir = os.path.join(os.getcwd(), "vidtok_cache")
os.makedirs(cache_dir, exist_ok=True)
# Download checkpoint
ckpt_path = hf_hub_download(
repo_id=repo_id,
filename=f"checkpoints/{checkpoint_filename}",
cache_dir=cache_dir,
local_dir=os.path.join(cache_dir, repo_id.replace("/", "_")),
)
print(f"Downloaded checkpoint to: {ckpt_path}")
            # Use the provided --config if any; otherwise fall back to a local copy
            # of the matching config (no download needed)
            if config_path is None:
                config_path = "/mnt/Meissonic/VidTok/configs/vidtok_v1_1/vidtok_kl_causal_488_16chn_v1_1.yaml"
except ImportError:
print("Error: huggingface_hub not installed. Install with: pip install huggingface_hub")
sys.exit(1)
except Exception as e:
print(f"Error downloading from HuggingFace: {e}")
print("Please provide --config and --ckpt paths, or install VidTok repository.")
sys.exit(1)
# Load VidTok model
if not VIDTOK_AVAILABLE or load_model_from_config is None:
print("Error: VidTok scripts not available. Please install VidTok:")
print(" git clone https://github.com/microsoft/VidTok")
print(" export PYTHONPATH=\"${PYTHONPATH}:$(pwd)/VidTok\"")
print("\nOr ensure you have git installed for automatic download.")
sys.exit(1)
print(f"Loading VidTok model from config: {config_path}")
print(f"Checkpoint: {ckpt_path}")
if VIDTOK_PATH:
print(f"Using VidTok from: {VIDTOK_PATH}")
model = load_model_from_config(config_path, ckpt_path)
model = model.to(device).eval()
# model.encoder = torch.compile(model.encoder)
# model.decoder = torch.compile(model.decoder)
# Check if model is causal
is_causal = getattr(model, 'is_causal', False)
print(f"Model is causal: {is_causal}")
if is_causal and args.num_frames == 16:
print("Warning: Causal models typically use 17 frames. Consider using --num_frames 17")
# Load dataset
print(f"Loading dataset from: {args.csv_path}")
# Auto-detect video_root_dir if not provided
video_root_dir = args.video_root_dir
if video_root_dir is None:
csv_dir = os.path.dirname(args.csv_path)
if os.path.exists(os.path.join(csv_dir, 'video_reorg')):
video_root_dir = os.path.join(csv_dir, 'video_reorg')
elif os.path.exists(os.path.join(os.path.dirname(csv_dir), 'video_reorg')):
video_root_dir = os.path.join(os.path.dirname(csv_dir), 'video_reorg')
else:
video_root_dir = csv_dir
print(f"Warning: Video directory not found, using CSV directory: {video_root_dir}")
# Initialize tokenizer for dataset (needed for OpenVid1MDataset)
tokenizer = T5Tokenizer.from_pretrained("google/umt5-base")
# Create dataset
dataset = OpenVid1MDataset(
csv_path=args.csv_path,
video_root_dir=video_root_dir,
tokenizer=tokenizer,
num_frames=args.num_frames,
height=args.height,
width=args.width,
text_encoder_architecture="umt5-base",
)
print(f"Dataset size: {len(dataset)}")
# Load video
if args.video_index >= len(dataset):
print(f"Error: video_index {args.video_index} >= dataset size {len(dataset)}")
return
print(f"Loading video at index {args.video_index}...")
sample = dataset[args.video_index]
original_video = sample["video"] # [C, F, H, W] in [0, 1]
# Get video info from dataset
row = dataset.data[args.video_index]
video_path = row.get('video', 'unknown')
caption = row.get('caption', 'no caption')
print(f"Video path: {video_path}")
print(f"Caption: {caption}")
print(f"Original video shape: {original_video.shape}")
print(f"Original video range: [{original_video.min():.3f}, {original_video.max():.3f}]")
# Convert to VidTok format: [B, C, T, H, W] in [-1, 1]
# Original video is [C, F, H, W] in [0, 1]
    original_video_vidtok = original_video.unsqueeze(0)  # [1, C, F, H, W] == [B, C, T, H, W]
# Convert from [0, 1] to [-1, 1]
original_video_vidtok = original_video_vidtok * 2.0 - 1.0
original_video_vidtok = original_video_vidtok.to(device=device)
print(f"VidTok input shape: {original_video_vidtok.shape}")
print(f"VidTok input range: [{original_video_vidtok.min():.3f}, {original_video_vidtok.max():.3f}]")
# Encode
print("\nEncoding video...")
    autocast_device = 'cuda' if device.type == 'cuda' else 'cpu'
    autocast_dtype = torch.float16 if device.type == 'cuda' else torch.float32
    with torch.no_grad(), torch.autocast(device_type=autocast_device, dtype=autocast_dtype):
if args.use_continuous:
# Encode to continuous latent space
z, reg_log = model.encode(original_video_vidtok, return_reg_log=True)
print(f"Continuous latent shape: {z.shape}")
print(f"Discrete tokens shape: {reg_log['indices'].shape if 'indices' in reg_log else 'N/A'}")
else:
# Full forward pass to get reconstruction
_, reconstructed_video_vidtok, _ = model(original_video_vidtok)
# Decode
if args.use_continuous:
print("\nDecoding from continuous latent space...")
        with torch.no_grad(), torch.autocast(device_type=autocast_device, dtype=autocast_dtype):
reconstructed_video_vidtok = model.decode(z)
else:
print("\nUsing reconstruction from forward pass...")
    # Drop the batch dimension: [1, C, T, H, W] -> [C, T, H, W]
    reconstructed_video_vidtok = reconstructed_video_vidtok.squeeze(0)
# Convert from [-1, 1] to [0, 1]
reconstructed_video = (reconstructed_video_vidtok + 1.0) / 2.0
reconstructed_video = torch.clamp(reconstructed_video, 0, 1)
print(f"Reconstructed video shape: {reconstructed_video.shape}")
print(f"Reconstructed video range: [{reconstructed_video.min():.3f}, {reconstructed_video.max():.3f}]")
# Ensure same number of frames for comparison
F_orig = original_video.shape[1]
F_recon = reconstructed_video.shape[1]
F_min = min(F_orig, F_recon)
original_video = original_video[:, :F_min, :, :]
reconstructed_video = reconstructed_video[:, :F_min, :, :]
# Resize if spatial dimensions don't match
if original_video.shape[2:] != reconstructed_video.shape[2:]:
print(f"Resizing reconstructed video from {reconstructed_video.shape[2:]} to {original_video.shape[2:]}")
# Use interpolation to resize
reconstructed_video_resized = torch.zeros_like(original_video)
for f in range(F_min):
frame = reconstructed_video[:, f, :, :].unsqueeze(0) # [1, C, H, W]
frame_resized = torch.nn.functional.interpolate(
frame, size=original_video.shape[2:], mode='bilinear', align_corners=False
)
reconstructed_video_resized[:, f, :, :] = frame_resized.squeeze(0)
reconstructed_video = reconstructed_video_resized
# Calculate metrics
print("\nCalculating metrics...")
# Convert to float32 for metric calculation
orig_f32 = original_video.to(torch.float32).cpu()
recon_f32 = reconstructed_video.to(torch.float32).cpu()
# Frame-wise metrics
psnr_values = []
mse_values = []
ssim_values = []
for f in range(F_min):
orig_frame = orig_f32[:, f, :, :] # [C, H, W]
recon_frame = recon_f32[:, f, :, :] # [C, H, W]
psnr = calculate_psnr(orig_frame, recon_frame)
mse = calculate_mse(orig_frame, recon_frame)
ssim = calculate_ssim(orig_frame, recon_frame)
psnr_values.append(psnr)
mse_values.append(mse)
ssim_values.append(ssim)
# Overall metrics
avg_psnr = np.mean(psnr_values)
avg_mse = np.mean(mse_values)
avg_ssim = np.mean(ssim_values)
print(f"\n=== Metrics ===")
print(f"PSNR: {avg_psnr:.2f} dB (per frame: {psnr_values})")
print(f"MSE: {avg_mse:.6f} (per frame: {mse_values})")
print(f"SSIM: {avg_ssim:.4f} (per frame: {ssim_values})")
# Save metrics to file
metrics_file = os.path.join(args.output_dir, f"metrics_video_{args.video_index}.txt")
with open(metrics_file, 'w') as f:
f.write(f"Video Index: {args.video_index}\n")
f.write(f"Video Path: {video_path}\n")
f.write(f"Caption: {caption}\n")
f.write(f"Model Config: {args.config}\n")
f.write(f"Model Checkpoint: {args.ckpt}\n")
f.write(f"Use Continuous Latent: {args.use_continuous}\n")
f.write(f"\n=== Metrics ===\n")
f.write(f"Average PSNR: {avg_psnr:.2f} dB\n")
f.write(f"Average MSE: {avg_mse:.6f}\n")
f.write(f"Average SSIM: {avg_ssim:.4f}\n")
f.write(f"\nPer-frame PSNR: {psnr_values}\n")
f.write(f"Per-frame MSE: {mse_values}\n")
f.write(f"Per-frame SSIM: {ssim_values}\n")
print(f"Saved metrics to: {metrics_file}")
# Create side-by-side video
print("\nCreating side-by-side comparison video...")
video_output_path = os.path.join(args.output_dir, f"comparison_video_{args.video_index}.mp4")
create_side_by_side_video(original_video, reconstructed_video, video_output_path, fps=8)
# Create comparison grid
print("Creating comparison grid...")
grid_output_path = os.path.join(args.output_dir, f"comparison_grid_video_{args.video_index}.png")
create_comparison_grid(original_video, reconstructed_video, grid_output_path, nrow=4)
print(f"\n=== Test Complete ===")
print(f"Results saved to: {args.output_dir}")
print(f" - Metrics: {metrics_file}")
print(f" - Side-by-side video: {video_output_path}")
print(f" - Comparison grid: {grid_output_path}")
if __name__ == "__main__":
main()