|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Test script for InfinityStar VQ-VAE performance. |
|
|
|
|
|
This script: |
|
|
1. Loads a video from the training dataset (same as test_cosmos_vqvae.py) |
|
|
2. Encodes it using InfinityStar VAE |
|
|
3. Decodes it back |
|
|
4. Computes metrics (PSNR, SSIM, MSE) - same as test_cosmos_vqvae.py |
|
|
5. Creates a side-by-side comparison video |
|
|
6. Saves the results |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import torch |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import cv2 |
|
|
from torchvision import transforms |
|
|
from torchvision.utils import make_grid, save_image |
|
|
|
|
|
|
|
|
# Make the Meissonic checkout importable when it is present at its
# expected mount point; silently skipped otherwise.
meissonic_path = "/mnt/Meissonic"
if os.path.exists(meissonic_path):
    sys.path.insert(0, meissonic_path)

# The dataset utilities live in the train/ subdirectory, which is not a
# package root, so it needs its own sys.path entry.
meissonic_train_path = os.path.join(meissonic_path, "train")
if os.path.exists(meissonic_train_path):
    sys.path.insert(0, meissonic_train_path)

# Also allow imports relative to this script's own directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
|
|
|
|
|
|
|
class SimpleArgs:
    """Minimal stand-in for the InfinityStar ``Args`` object.

    Carries only the attributes the video-VAE loading code reads, so the
    ``tap`` argument-parsing dependency is not required.
    """

    # Default hyper-parameters consumed by the VAE loader. Kept in one
    # table so the full attribute surface is visible at a glance.
    _DEFAULTS = {
        "semantic_scale_dim": 16,
        "detail_scale_dim": 64,
        "use_learnable_dim_proj": 0,
        "detail_scale_min_tokens": 80,
        "use_feat_proj": 2,
        "semantic_scales": 8,
        "vae_path": "",
        "vae_type": 18,
        "videovae": 10,
    }

    def __init__(self):
        # Materialize every default as a plain instance attribute so the
        # object behaves exactly like the original hand-written version.
        for attr_name, default_value in self._DEFAULTS.items():
            setattr(self, attr_name, default_value)
|
|
|
|
|
|
|
|
import sys |
|
|
import importlib.util |
|
|
|
|
|
|
|
|
def load_visual_tokenizer_safe(args, device=None):
    """Load the InfinityStar visual tokenizer without importing ``arg_util``.

    Args:
        args: Namespace-like object providing ``vae_type`` (doubles as the
            codebook dimension), ``videovae`` (model-family selector) and
            ``vae_path`` (checkpoint path).
        device: Target ``torch.device``; defaults to CUDA when available.

    Returns:
        The VAE module, moved to ``device`` and constructed in test mode.

    Raises:
        ValueError: If ``vae_type`` or ``videovae`` has an unsupported value.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Supported codebook dimensions; vae_type is reused as codebook_dim.
    if args.vae_type in [8, 12, 14, 16, 18, 20, 24, 32, 48, 64, 128]:
        schedule_mode = "dynamic"
        codebook_dim = args.vae_type
        print(f'Load VAE from {args.vae_path}')
        if args.videovae == 10:
            # Deferred import: only pull in the heavy model code when needed.
            from infinity.models.videovae.models.load_vae_bsq_wan_absorb_patchify import video_vae_model
            vae_local = video_vae_model(args.vae_path, schedule_mode, codebook_dim, global_args=args, test_mode=True).to(device)
        else:
            # BUG FIX: this branch is taken for an unsupported *videovae*
            # value, so report that field instead of repeating vae_type.
            raise ValueError(f"videovae {args.videovae} not supported")
    else:
        raise ValueError(f"vae_type {args.vae_type} not supported")
    return vae_local
|
|
|
|
|
|
|
|
try:
    # Load the OpenVid1M dataset utilities directly from the Meissonic
    # checkout by file path, so Meissonic's own package layout does not
    # have to be importable.
    import importlib.util
    dataset_utils_path = os.path.join(meissonic_path, "train", "dataset_utils.py")
    if os.path.exists(dataset_utils_path):
        spec = importlib.util.spec_from_file_location("meissonic_dataset_utils", dataset_utils_path)
        dataset_utils = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(dataset_utils)
        OpenVid1MDataset = dataset_utils.OpenVid1MDataset
        from transformers import T5Tokenizer
        # Flag consumed by main() to choose real data vs. a dummy video.
        DATASET_AVAILABLE = True
        print(f"Loaded dataset utilities from Meissonic: {dataset_utils_path}")
    else:
        raise ImportError(f"Could not find dataset_utils.py at {dataset_utils_path}")
except Exception as e:
    # Broad catch is deliberate: any failure here only downgrades the
    # script to dummy-video mode instead of aborting the test run.
    DATASET_AVAILABLE = False
    print(f"Warning: Could not import dataset utilities: {e}")
    print("Will use direct video loading.")
|
|
|
|
|
|
|
|
def calculate_psnr(img1, img2, max_val=1.0):
    """Peak signal-to-noise ratio (in dB) between two images.

    Accepts tensors or anything ``torch.tensor`` can consume. Identical
    inputs yield ``float('inf')``.

    Args:
        img1: First image (tensor or array-like).
        img2: Second image (tensor or array-like).
        max_val: Maximum possible pixel value (1.0 for normalized images).
    """
    def _as_cpu_tensor(x):
        # Normalize either input to a CPU tensor before comparing.
        return x.cpu() if isinstance(x, torch.Tensor) else torch.tensor(x)

    a = _as_cpu_tensor(img1)
    b = _as_cpu_tensor(img2)
    mse = torch.mean((a - b) ** 2)
    if mse == 0:
        # Zero error means PSNR is unbounded.
        return float('inf')
    return (20 * torch.log10(max_val / torch.sqrt(mse))).item()
|
|
|
|
|
|
|
|
def calculate_mse(img1, img2):
    """Mean squared error between two images (tensors or array-likes)."""
    # Bring both inputs onto the CPU as tensors.
    pair = [
        img.cpu() if isinstance(img, torch.Tensor) else torch.tensor(img)
        for img in (img1, img2)
    ]
    diff = pair[0] - pair[1]
    return (diff * diff).mean().item()
|
|
|
|
|
|
|
|
def calculate_ssim(img1, img2, window_size=11):
    """Simplified global SSIM between two images.

    Statistics are computed once over the whole image rather than in a
    sliding window, so ``window_size`` is currently unused (kept for API
    compatibility with windowed SSIM implementations).
    """
    x = img1.cpu() if isinstance(img1, torch.Tensor) else torch.tensor(img1)
    y = img2.cpu() if isinstance(img2, torch.Tensor) else torch.tensor(img2)

    # Stabilizing constants for a dynamic range of 1.0.
    c1 = 0.01 ** 2
    c2 = 0.03 ** 2

    mu_x, mu_y = x.mean(), y.mean()
    var_x, var_y = x.var(), y.var()
    cov_xy = ((x - mu_x) * (y - mu_y)).mean()

    numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    denominator = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return (numerator / denominator).item()
|
|
|
|
|
|
|
|
def video_to_numpy(video_tensor):
    """
    Convert a video tensor [C, F, H, W] in [0, 1] to a uint8 numpy array
    [F, H, W, C] in [0, 255] (RGB). Non-tensor inputs are passed through
    ``np.array`` unchanged.
    """
    if not isinstance(video_tensor, torch.Tensor):
        return np.array(video_tensor)

    # [C, F, H, W] -> [F, C, H, W] on CPU, then to channel-last numpy.
    frames = video_tensor.permute(1, 0, 2, 3).cpu().numpy()
    frames = np.transpose(frames, (0, 2, 3, 1))
    # Map [0, 1] floats to [0, 255] bytes (out-of-range values clipped).
    frames = np.clip(frames, 0, 1)
    return (frames * 255).astype(np.uint8)
|
|
|
|
|
|
|
|
def create_side_by_side_video(original, reconstructed, output_path, fps=8):
    """
    Write an MP4 showing the original and reconstructed videos side by side.

    Args:
        original: Original video tensor [C, F, H, W] or numpy array
        reconstructed: Reconstructed video tensor [C, F, H, W] or numpy array
        output_path: Path to save the output video
        fps: Frames per second

    Raises:
        ValueError: If there are no frames to write.
    """
    orig_frames = video_to_numpy(original)
    recon_frames = video_to_numpy(reconstructed)

    num_orig, height_o, width_o, _ = orig_frames.shape
    num_recon, height_r, width_r, _ = recon_frames.shape

    # Keep only the frames both clips share (temporal compression in the
    # VAE can make the reconstruction shorter).
    shared = min(num_orig, num_recon)
    orig_frames = orig_frames[:shared]
    recon_frames = recon_frames[:shared]

    # Match spatial resolution so horizontal concatenation is valid.
    if (height_o, width_o) != (height_r, width_r):
        recon_frames = np.array(
            [cv2.resize(fr, (width_o, height_o)) for fr in recon_frames]
        )

    stitched = []
    for idx in range(shared):
        left = orig_frames[idx].copy()
        right = recon_frames[idx].copy()
        # putText draws in place, hence the copies above.
        cv2.putText(left, "Original", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
        cv2.putText(right, "Reconstructed", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2)
        stitched.append(np.concatenate([left, right], axis=1))

    if not stitched:
        raise ValueError("No frames to save")

    frame_h, frame_w = stitched[0].shape[:2]
    writer = cv2.VideoWriter(output_path,
                             cv2.VideoWriter_fourcc(*'mp4v'),
                             fps, (frame_w, frame_h))
    for frame in stitched:
        # OpenCV writers expect BGR channel order.
        writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    writer.release()
    print(f"Saved side-by-side video to: {output_path}")
|
|
|
|
|
|
|
|
def add_text_to_image(image_tensor, text, position=(10, 30)):
    """
    Add a white text label with a 1-px black outline to an image tensor.

    Args:
        image_tensor: Image tensor [C, H, W] in [0, 1]
        text: Text to add
        position: (x, y) position for text
    Returns:
        Image tensor with text [C, H, W]
    """
    # Tensor [C, H, W] in [0, 1] -> uint8 HWC array -> PIL image.
    image_np = image_tensor.permute(1, 2, 0).cpu().numpy()
    image_np = np.clip(image_np, 0, 1)
    image_np = (image_np * 255).astype(np.uint8)
    pil_image = Image.fromarray(image_np)

    from PIL import ImageDraw, ImageFont
    draw = ImageDraw.Draw(pil_image)

    # Try common TrueType font locations (Linux first, then macOS) and
    # fall back to PIL's built-in bitmap font. BUG FIX: the original used
    # bare `except:` clauses here, which also swallow KeyboardInterrupt
    # and unrelated errors; ImageFont.truetype raises OSError when the
    # font file cannot be opened, so catch only that.
    font = None
    for font_path in (
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
        "/System/Library/Fonts/Helvetica.ttc",
    ):
        try:
            font = ImageFont.truetype(font_path, 24)
            break
        except OSError:
            continue
    if font is None:
        font = ImageFont.load_default()

    x, y = position
    # Stamp the text at the 8 surrounding offsets in black to form an
    # outline, then draw the white text on top.
    for dx, dy in [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]:
        draw.text((x + dx, y + dy), text, font=font, fill=(0, 0, 0))
    draw.text((x, y), text, font=font, fill=(255, 255, 255))

    # Back to a [C, H, W] float tensor in [0, 1].
    image_tensor = transforms.ToTensor()(pil_image)
    return image_tensor
|
|
|
|
|
|
|
|
def create_comparison_grid(original, reconstructed, output_path, nrow=4):
    """
    Save a grid image of labeled original/reconstructed frame pairs.

    Args:
        original: Original video tensor [C, F, H, W]
        reconstructed: Reconstructed video tensor [C, F, H, W]
        output_path: Path to save the grid image
        nrow: Number of frame pairs per row
    """
    shared_frames = min(original.shape[1], reconstructed.shape[1])

    # Sample up to 8 frames evenly across the shared temporal span.
    sample_count = min(8, shared_frames)
    sampled_indices = np.linspace(0, shared_frames - 1, sample_count, dtype=int)

    labeled = []
    for frame_idx in sampled_indices:
        # Interleave original/reconstruction so each pair sits side by side
        # in the final grid.
        labeled.append(add_text_to_image(
            original[:, frame_idx, :, :].clone(), "Original", position=(10, 10)))
        labeled.append(add_text_to_image(
            reconstructed[:, frame_idx, :, :].clone(), "Reconstructed", position=(10, 10)))

    # nrow pairs per row -> nrow * 2 individual tiles per row; white padding.
    grid = make_grid(torch.stack(labeled, dim=0), nrow=nrow * 2, padding=2, pad_value=1.0)

    save_image(grid, output_path)
    print(f"Saved comparison grid to: {output_path}")
|
|
|
|
|
|
|
|
def main():
    """Run an encode/decode round trip through the InfinityStar video VAE
    and report PSNR/MSE/SSIM plus visual comparisons."""
    # --- Model configuration -------------------------------------------------
    VAE_PATH = "/mnt/Meissonic/InfinityStar/infinitystar_videovae.pth"
    VAE_TYPE = 18
    VIDEOVAE = 10

    # --- Dataset configuration -----------------------------------------------
    CSV_PATH = "/mnt/VideoGen/dataset/OpenVid1M/video_reorg/OpenVid1M_reorganized.csv"
    VIDEO_ROOT_DIR = None
    VIDEO_INDEX = 3

    # --- Clip shape requested from the dataset -------------------------------
    NUM_FRAMES = 16
    HEIGHT = 480
    WIDTH = 848

    # --- Runtime settings ----------------------------------------------------
    OUTPUT_DIR = "./infinity_vqvae_test_output"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    DTYPE = "float32"

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Resolve the string dtype setting to a torch dtype.
    device = torch.device(DEVICE)
    if DTYPE == "float16":
        dtype = torch.float16
    elif DTYPE == "bfloat16":
        dtype = torch.bfloat16
    else:
        dtype = torch.float32

    print(f"Using device: {device}, dtype: {dtype}")

    # --- Load the VQ-VAE -----------------------------------------------------
    print("=" * 80)
    print("Loading VQ-VAE model...")
    print(f" VAE path: {VAE_PATH}")
    print(f" VAE type: {VAE_TYPE}")
    print(f" Video VAE: {VIDEOVAE}")
    print("=" * 80)

    vae_args = SimpleArgs()
    vae_args.vae_path = VAE_PATH
    vae_args.vae_type = VAE_TYPE
    vae_args.videovae = VIDEOVAE

    vae = load_visual_tokenizer_safe(vae_args, device=device)
    vae = vae.to(device)
    vae.eval()

    # Freeze all parameters (side-effect comprehension; result discarded).
    [p.requires_grad_(False) for p in vae.parameters()]

    print("VAE loaded successfully!")
    print(f" Device: {device}")
    print(f" Model dtype: {next(vae.parameters()).dtype}")
    print(f" Model in eval mode: {not vae.training}")

    # --- Load a test video ---------------------------------------------------
    if DATASET_AVAILABLE:
        print(f"\nLoading dataset from: {CSV_PATH}")

        # Locate the video root directory: try a video_reorg/ sibling of the
        # CSV, then its parent, and finally fall back to the CSV directory.
        video_root_dir = VIDEO_ROOT_DIR
        if video_root_dir is None:
            csv_dir = os.path.dirname(CSV_PATH)
            if os.path.exists(os.path.join(csv_dir, 'video_reorg')):
                video_root_dir = os.path.join(csv_dir, 'video_reorg')
            elif os.path.exists(os.path.join(os.path.dirname(csv_dir), 'video_reorg')):
                video_root_dir = os.path.join(os.path.dirname(csv_dir), 'video_reorg')
            else:
                video_root_dir = csv_dir
                print(f"Warning: Video directory not found, using CSV directory: {video_root_dir}")

        # Tokenizer is required by the dataset constructor even though this
        # script only uses the video frames.
        tokenizer = T5Tokenizer.from_pretrained("google/umt5-base")

        dataset = OpenVid1MDataset(
            csv_path=CSV_PATH,
            video_root_dir=video_root_dir,
            tokenizer=tokenizer,
            num_frames=NUM_FRAMES,
            height=HEIGHT,
            width=WIDTH,
            text_encoder_architecture="umt5-base",
        )

        print(f"Dataset size: {len(dataset)}")

        if VIDEO_INDEX >= len(dataset):
            print(f"Error: video_index {VIDEO_INDEX} >= dataset size {len(dataset)}")
            return

        print(f"Loading video at index {VIDEO_INDEX}...")
        sample = dataset[VIDEO_INDEX]
        original_video = sample["video"]

        # Normalize the layout to [C, T, H, W] regardless of what the
        # dataset returned.
        if original_video.dim() == 4:
            if original_video.shape[0] == NUM_FRAMES and original_video.shape[1] == 3:
                print(f"Detected [T, C, H, W] format, converting to [C, T, H, W]")
                original_video = original_video.permute(1, 0, 2, 3)
            elif original_video.shape[-1] == 3:
                print(f"Detected [T, H, W, C] format, converting to [C, T, H, W]")
                original_video = original_video.permute(3, 0, 1, 2)

        # Metadata for the report (assumes dataset.data is indexable with
        # dict-like rows -- matches the CSV-backed dataset used here).
        row = dataset.data[VIDEO_INDEX]
        video_path = row.get('video', 'unknown')
        caption = row.get('caption', 'no caption')

        print(f"Video path: {video_path}")
        print(f"Caption: {caption}")
    else:
        # Dataset utilities failed to import: fall back to random frames so
        # the round-trip machinery can still be exercised.
        print("Warning: Dataset utilities not available. Using dummy video.")
        original_video = torch.rand(3, NUM_FRAMES, HEIGHT, WIDTH)
        video_path = "dummy"
        caption = "dummy video"

    print(f"Original video shape (C, T, H, W): {original_video.shape}")
    print(f"Original video range (from dataset): [{original_video.min():.3f}, {original_video.max():.3f}]")

    # --- Prepare input for the VAE ([0, 1] -> [-1, 1], add batch dim) --------
    video_for_vae = original_video.to(device=device, dtype=dtype)

    video_for_vae = video_for_vae.clamp(0.0, 1.0)
    print("Dataset returns [0, 1], converting to [-1, 1] for VAE")
    video_for_vae = video_for_vae * 2.0 - 1.0

    print(f"Video for VAE range: [{video_for_vae.min():.3f}, {video_for_vae.max():.3f}]")

    video_for_vae = video_for_vae.unsqueeze(0)

    # --- Encode --------------------------------------------------------------
    print("\n" + "=" * 80)
    print("Encoding using vae.encode_for_raw_features (InfinityStar's method)...")
    print("=" * 80)

    with torch.no_grad():
        raw_features, _, _ = vae.encode_for_raw_features(
            video_for_vae,
            scale_schedule=None,
            slice=True
        )
    print(f"Encoded latent shape: {raw_features.shape}")
    print(f"Encoded latent range: [{raw_features.min().item():.4f}, {raw_features.max().item():.4f}]")

    # --- Decode --------------------------------------------------------------
    print("\n" + "=" * 80)
    print("Decoding using vae.decode (InfinityStar's method)...")
    print("=" * 80)

    with torch.no_grad():
        reconstructed_video_batch = vae.decode(raw_features, slice=True)
        if isinstance(reconstructed_video_batch, tuple):
            reconstructed_video_batch = reconstructed_video_batch[0]
        # Decoder output is expected in [-1, 1]; clamp defensively.
        reconstructed_video_batch = torch.clamp(reconstructed_video_batch, min=-1, max=1)

    print(f"Reconstructed shape: {reconstructed_video_batch.shape}")
    print(f"Reconstructed range: [{reconstructed_video_batch.min():.3f}, {reconstructed_video_batch.max():.3f}]")

    # Drop the batch dimension.
    reconstructed_video = reconstructed_video_batch.squeeze(0)

    # --- Normalize both videos to [0, 1] for metrics and visualization -------
    if reconstructed_video.min() < 0:
        print("Reconstructed video is in [-1, 1], converting to [0, 1]")
        reconstructed_video_01 = (reconstructed_video + 1.0) / 2.0
    else:
        print("Reconstructed video is already in [0, 1]")
        reconstructed_video_01 = reconstructed_video.clone()
    reconstructed_video_01 = torch.clamp(reconstructed_video_01, 0, 1)
    print(f"Reconstructed video [0, 1] range: [{reconstructed_video_01.min():.3f}, {reconstructed_video_01.max():.3f}]")

    # Heuristic range detection for the original as well ([-1,1] or [0,255]).
    original_video_01 = original_video.clone().to(device=device)
    if original_video_01.min() < 0:
        original_video_01 = (original_video_01 + 1.0) / 2.0
    elif original_video_01.max() > 1.0:
        original_video_01 = original_video_01 / 255.0
    original_video_01 = torch.clamp(original_video_01, 0, 1)
    print(f"Original video [0, 1] range: [{original_video_01.min():.3f}, {original_video_01.max():.3f}]")

    # --- Align frame counts (temporal compression may shorten the output) ----
    F_orig = original_video_01.shape[1]
    F_recon = reconstructed_video_01.shape[1]
    F_min = min(F_orig, F_recon)

    if F_orig != F_recon:
        print(f"Frame count mismatch: original={F_orig}, reconstructed={F_recon}, using first {F_min} frames for comparison")
        print(" (This is normal for VAE with temporal compression)")

    original_video_01 = original_video_01[:, :F_min, :, :]
    reconstructed_video_01 = reconstructed_video_01[:, :F_min, :, :]

    # --- Align spatial size (frame-by-frame bilinear resize) -----------------
    if original_video_01.shape[2:] != reconstructed_video_01.shape[2:]:
        print(f"Resizing reconstructed video from {reconstructed_video_01.shape[2:]} to {original_video_01.shape[2:]}")

        reconstructed_video_resized = torch.zeros_like(original_video_01)
        for f in range(F_min):
            frame = reconstructed_video_01[:, f, :, :].unsqueeze(0)
            frame_resized = torch.nn.functional.interpolate(
                frame, size=original_video_01.shape[2:], mode='bilinear', align_corners=False
            )
            reconstructed_video_resized[:, f, :, :] = frame_resized.squeeze(0)
        reconstructed_video_01 = reconstructed_video_resized

    # --- Per-frame metrics ---------------------------------------------------
    print("\nCalculating metrics...")

    # Compare in float32 regardless of the inference dtype.
    orig_f32 = original_video_01.to(torch.float32)
    recon_f32 = reconstructed_video_01.to(torch.float32)

    psnr_values = []
    mse_values = []
    ssim_values = []

    for f in range(F_min):
        orig_frame = orig_f32[:, f, :, :]
        recon_frame = recon_f32[:, f, :, :]

        psnr = calculate_psnr(orig_frame, recon_frame)
        mse = calculate_mse(orig_frame, recon_frame)
        ssim = calculate_ssim(orig_frame, recon_frame)

        psnr_values.append(psnr)
        mse_values.append(mse)
        ssim_values.append(ssim)

    avg_psnr = np.mean(psnr_values)
    avg_mse = np.mean(mse_values)
    avg_ssim = np.mean(ssim_values)

    print(f"\n=== Metrics ===")
    print(f"PSNR: {avg_psnr:.2f} dB (per frame: {psnr_values})")
    print(f"MSE: {avg_mse:.6f} (per frame: {mse_values})")
    print(f"SSIM: {avg_ssim:.4f} (per frame: {ssim_values})")

    # --- Persist metrics to a text report ------------------------------------
    metrics_file = os.path.join(OUTPUT_DIR, f"metrics_video_{VIDEO_INDEX}.txt")
    with open(metrics_file, 'w') as f:
        f.write(f"Video Index: {VIDEO_INDEX}\n")
        f.write(f"Video Path: {video_path}\n")
        f.write(f"Caption: {caption}\n")
        f.write(f"\n=== Metrics ===\n")
        f.write(f"Average PSNR: {avg_psnr:.2f} dB\n")
        f.write(f"Average MSE: {avg_mse:.6f}\n")
        f.write(f"Average SSIM: {avg_ssim:.4f}\n")
        f.write(f"\nPer-frame PSNR: {psnr_values}\n")
        f.write(f"Per-frame MSE: {mse_values}\n")
        f.write(f"Per-frame SSIM: {ssim_values}\n")

    print(f"Saved metrics to: {metrics_file}")

    # --- Visual outputs ------------------------------------------------------
    print("\nCreating side-by-side comparison video...")
    video_output_path = os.path.join(OUTPUT_DIR, f"comparison_video_{VIDEO_INDEX}.mp4")
    create_side_by_side_video(original_video_01, reconstructed_video_01, video_output_path, fps=8)

    print("Creating comparison grid...")
    grid_output_path = os.path.join(OUTPUT_DIR, f"comparison_grid_video_{VIDEO_INDEX}.png")
    create_comparison_grid(original_video_01, reconstructed_video_01, grid_output_path, nrow=4)

    print(f"\n=== Test Complete ===")
    print(f"Results saved to: {OUTPUT_DIR}")
    print(f" - Metrics: {metrics_file}")
    print(f" - Side-by-side video: {video_output_path}")
    print(f" - Comparison grid: {grid_output_path}")
|
|
|
|
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|
|