WorldMem_Repro / calculate_fid.py
BonanDing's picture
update lfs
8652b14
#!/usr/bin/env python3
"""
Calculate FID (Fréchet Inception Distance) between predicted and ground truth videos.
Usage:
python calculate_fid.py --videos_dir /path/to/videos
python calculate_fid.py --videos_dir /path/to/videos --batch_size 32
"""
import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm
import argparse
import cv2
from torchmetrics.image.fid import FrechetInceptionDistance
def load_video_frames(video_path, max_frames=None):
    """
    Load frames from a video file.

    Frames are read with OpenCV, converted from BGR to RGB channel order,
    and stacked into a single tensor.

    Args:
        video_path: Path to the video file
        max_frames: Maximum number of frames to load (None or 0 = all frames)

    Returns:
        torch.Tensor: Video frames with shape (T, C, H, W) in range [0, 255]

    Raises:
        ValueError: If no frame could be read from the video
    """
    cap = cv2.VideoCapture(str(video_path))
    frames = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                # End of stream, or the file could not be opened/decoded.
                break
            # Convert BGR to RGB
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            if max_frames and len(frames) >= max_frames:
                break
    finally:
        # Always release the capture handle, even if reading or colour
        # conversion raises part-way through the video.
        cap.release()

    if not frames:
        raise ValueError(f"No frames loaded from {video_path}")

    # Convert to tensor: (T, H, W, C) -> (T, C, H, W)
    stacked = np.stack(frames, axis=0)
    return torch.from_numpy(stacked).permute(0, 3, 1, 2)
def load_videos_from_directory(video_dir, max_frames_per_video=None, max_videos=None):
    """
    Collect the frames of every .mp4 found under a directory into one tensor.

    The directory is searched recursively and the files are visited in
    sorted path order. A video that fails to load is reported with a
    warning and skipped.

    Args:
        video_dir: Directory containing .mp4 files
        max_frames_per_video: Maximum frames to load per video
        max_videos: Maximum number of videos to load

    Returns:
        torch.Tensor: All frames concatenated with shape (N, C, H, W)

    Raises:
        ValueError: If not a single video could be loaded
    """
    root = Path(video_dir)
    mp4_files = sorted(root.glob("**/*.mp4"))
    if max_videos:
        mp4_files = mp4_files[:max_videos]

    print(f"Loading videos from {root}")
    print(f"Found {len(mp4_files)} videos")

    per_video = []
    for mp4 in tqdm(mp4_files, desc="Loading videos"):
        try:
            clip = load_video_frames(mp4, max_frames=max_frames_per_video)
        except Exception as e:
            # Best effort: report the unreadable file and move on to the next.
            print(f"\nWarning: Failed to load {mp4.name}: {e}")
        else:
            per_video.append(clip)

    if not per_video:
        raise ValueError(f"No videos loaded from {root}")

    # Join the per-video (T, C, H, W) tensors along the frame axis.
    merged = torch.cat(per_video, dim=0)
    print(f"Loaded {merged.shape[0]} frames total")
    print(f"Frame shape: {merged.shape[1:]}")
    return merged
def _update_fid_in_batches(fid_model, frames, batch_size, device, real):
    """
    Feed frames to the FID metric in chunks of at most batch_size.

    Args:
        fid_model: FrechetInceptionDistance instance to update
        frames: Tensor of frames, indexed along dim 0
        batch_size: Number of frames moved to the device per update
        device: Device the batches are moved to
        real: True for ground truth frames, False for predicted frames
    """
    num_frames = frames.shape[0]
    for start in tqdm(range(0, num_frames, batch_size)):
        batch = frames[start:start + batch_size].to(device)
        fid_model.update(batch, real=real)


def calculate_fid(pred_dir, gt_dir, batch_size=32, device='cuda',
                  max_frames_per_video=None, max_videos=None):
    """
    Calculate FID between predicted and ground truth videos.

    Every frame of every loaded video is treated as an independent image:
    predicted frames form the "fake" distribution and ground truth frames
    form the "real" distribution.

    Args:
        pred_dir: Directory containing predicted videos
        gt_dir: Directory containing ground truth videos
        batch_size: Batch size for FID calculation (must be positive)
        device: Device to use ('cuda' or 'cpu')
        max_frames_per_video: Maximum frames to load per video
        max_videos: Maximum number of videos to load from each directory

    Returns:
        float: FID score

    Raises:
        ValueError: If a directory does not exist or batch_size is not positive
    """
    print("="*60)
    print("FID Calculation")
    print("="*60)
    print(f"Pred directory: {pred_dir}")
    print(f"GT directory: {gt_dir}")
    print(f"Device: {device}")
    print(f"Batch size: {batch_size}")
    print("="*60 + "\n")

    # Check if directories exist
    pred_dir = Path(pred_dir)
    gt_dir = Path(gt_dir)
    if not pred_dir.exists():
        raise ValueError(f"Pred directory does not exist: {pred_dir}")
    if not gt_dir.exists():
        raise ValueError(f"GT directory does not exist: {gt_dir}")
    # Fail fast: a non-positive batch size would otherwise only surface after
    # all videos have been decoded into memory.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Load videos
    print("\n[1/3] Loading predicted videos...")
    pred_frames = load_videos_from_directory(
        pred_dir,
        max_frames_per_video=max_frames_per_video,
        max_videos=max_videos
    )
    print("\n[2/3] Loading ground truth videos...")
    gt_frames = load_videos_from_directory(
        gt_dir,
        max_frames_per_video=max_frames_per_video,
        max_videos=max_videos
    )

    # Initialize FID model
    print("\n[3/3] Calculating FID...")
    # The loaded frames are byte images in [0, 255]. normalize=True would tell
    # the metric they are floats in [0, 1] and make it rescale them by 255,
    # which overflows byte tensors and corrupts every image before feature
    # extraction. normalize=False passes the frames through unchanged.
    fid_model = FrechetInceptionDistance(normalize=False).to(device)

    print("Processing predicted frames...")
    _update_fid_in_batches(fid_model, pred_frames, batch_size, device, real=False)

    print("Processing ground truth frames...")
    _update_fid_in_batches(fid_model, gt_frames, batch_size, device, real=True)

    # Compute FID
    return fid_model.compute().item()
def main(argv=None):
    """
    Command-line entry point: parse arguments, compute FID, print and save it.

    The prediction and ground truth directories default to the 'pred' and
    'gt' subdirectories of --videos_dir unless overridden. The score is
    printed and also written to 'fid_results.txt' inside --videos_dir.

    Args:
        argv: Optional list of argument strings. None (the default) makes
            argparse read the process command line.

    Returns:
        int: 0 on success, 1 if calculating or saving the FID failed
    """
    parser = argparse.ArgumentParser(
        description="Calculate FID between predicted and ground truth videos"
    )
    parser.add_argument(
        "--videos_dir",
        type=str,
        default="/mnt/worldmem_valid/outputs/2025-12-01/08-09-46/videos/test_vis",
        help="Base directory containing 'pred' and 'gt' subdirectories"
    )
    parser.add_argument(
        "--pred_dir",
        type=str,
        default=None,
        help="Override pred directory (default: {videos_dir}/pred)"
    )
    parser.add_argument(
        "--gt_dir",
        type=str,
        default=None,
        help="Override gt directory (default: {videos_dir}/gt)"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size for FID calculation (default: 32)"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use (default: cuda if available)"
    )
    parser.add_argument(
        "--max_frames_per_video",
        type=int,
        default=None,
        help="Maximum frames to load per video (default: None, load all)"
    )
    parser.add_argument(
        "--max_videos",
        type=int,
        default=50,
        # The help text previously advertised "None, load all" although the
        # effective default is 50; it now describes the real behaviour.
        help="Maximum number of videos to load from each directory "
             "(default: 50; 0 = load all)"
    )
    args = parser.parse_args(argv)

    # Determine pred and gt directories
    videos_dir = Path(args.videos_dir)
    pred_dir = Path(args.pred_dir) if args.pred_dir else videos_dir / "pred"
    gt_dir = Path(args.gt_dir) if args.gt_dir else videos_dir / "gt"

    # Calculate FID
    try:
        fid_score = calculate_fid(
            pred_dir=pred_dir,
            gt_dir=gt_dir,
            batch_size=args.batch_size,
            device=args.device,
            max_frames_per_video=args.max_frames_per_video,
            max_videos=args.max_videos
        )

        # Print results
        print("\n" + "="*60)
        print("RESULTS")
        print("="*60)
        print(f"FID Score: {fid_score:.4f}")
        print("="*60)

        # Save results to file
        output_file = videos_dir / "fid_results.txt"
        with open(output_file, 'w', encoding="utf-8") as f:
            f.write(f"FID Score: {fid_score:.4f}\n")
            f.write(f"Pred directory: {pred_dir}\n")
            f.write(f"GT directory: {gt_dir}\n")
        print(f"\nResults saved to: {output_file}")
    except Exception as e:
        # Top-level boundary: report the failure with a traceback and turn
        # it into a non-zero exit status instead of crashing.
        print(f"\n✗ Error: {e}")
        import traceback
        traceback.print_exc()
        return 1
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status. SystemExit
    # is raised directly instead of calling exit(), which is a helper injected
    # by the site module for interactive use and is not guaranteed to exist
    # (for example under `python -S`).
    raise SystemExit(main())