| |
| """ |
| Calculate FID (Fréchet Inception Distance) between predicted and ground truth videos. |
| |
| Usage: |
| python calculate_fid.py --videos_dir /path/to/videos |
| python calculate_fid.py --videos_dir /path/to/videos --batch_size 32 |
| """ |
|
|
| import torch |
| import numpy as np |
| from pathlib import Path |
| from tqdm import tqdm |
| import argparse |
| import cv2 |
| from torchmetrics.image.fid import FrechetInceptionDistance |
|
|
|
|
def load_video_frames(video_path, max_frames=None):
    """
    Load frames from a video file.

    Each decoded frame is converted from OpenCV's BGR channel order to RGB.
    The frames are then stacked along a new leading axis and permuted so the
    channel axis comes second.

    Args:
        video_path: Path to the video file
        max_frames: Maximum number of frames to load (None or 0 = all frames)

    Returns:
        torch.Tensor: Video frames with shape (T, C, H, W) in range [0, 255]

    Raises:
        ValueError: If the video cannot be opened or yields no frames
    """
    cap = cv2.VideoCapture(str(video_path))
    frames = []

    try:
        # Fail early with a precise message instead of the generic
        # "No frames loaded" when the file is missing or undecodable.
        if not cap.isOpened():
            raise ValueError(f"Could not open video {video_path}")

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Convert BGR -> RGB before storing the frame.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

            # A falsy max_frames (None or 0) means "no limit".
            if max_frames and len(frames) >= max_frames:
                break
    finally:
        # Always release the capture handle, even if decoding raised.
        cap.release()

    if not frames:
        raise ValueError(f"No frames loaded from {video_path}")

    # (T, H, W, C) numpy stack -> (T, C, H, W) tensor.
    stacked = np.stack(frames, axis=0)
    return torch.from_numpy(stacked).permute(0, 3, 1, 2)
|
|
|
|
def load_videos_from_directory(video_dir, max_frames_per_video=None, max_videos=None):
    """
    Load every .mp4 video under a directory into one frame tensor.

    The directory is searched recursively, the paths are sorted, and videos
    that fail to load are reported with a warning and skipped.

    Args:
        video_dir: Directory containing .mp4 files
        max_frames_per_video: Maximum frames to load per video
        max_videos: Maximum number of videos to load (falsy = no limit)

    Returns:
        torch.Tensor: All frames concatenated with shape (N, C, H, W)

    Raises:
        ValueError: If no video could be loaded from the directory
    """
    root = Path(video_dir)
    candidates = sorted(root.glob("**/*.mp4"))
    if max_videos:
        candidates = candidates[:max_videos]

    print(f"Loading videos from {root}")
    print(f"Found {len(candidates)} videos")

    per_video = []
    for path in tqdm(candidates, desc="Loading videos"):
        try:
            clip = load_video_frames(path, max_frames=max_frames_per_video)
            per_video.append(clip)
        except Exception as e:
            # Best effort: one unreadable video must not abort the whole run.
            print(f"\nWarning: Failed to load {path.name}: {e}")

    if not per_video:
        raise ValueError(f"No videos loaded from {root}")

    # Merge the per-video tensors along the frame axis.
    merged = torch.cat(per_video, dim=0)

    print(f"Loaded {merged.shape[0]} frames total")
    print(f"Frame shape: {merged.shape[1:]}")

    return merged
|
|
|
|
def calculate_fid(pred_dir, gt_dir, batch_size=32, device='cuda',
                  max_frames_per_video=None, max_videos=None):
    """
    Calculate FID between predicted and ground truth videos.

    All frames of both directories are loaded first, then fed to the metric
    in batches: predicted frames as the "fake" distribution and ground truth
    frames as the "real" one.

    Args:
        pred_dir: Directory containing predicted videos
        gt_dir: Directory containing ground truth videos
        batch_size: Batch size for FID calculation
        device: Device to use ('cuda' or 'cpu')
        max_frames_per_video: Maximum frames to load per video
        max_videos: Maximum number of videos to load from each directory

    Returns:
        float: FID score

    Raises:
        ValueError: If either directory does not exist, or no videos could
            be loaded from it
    """
    print("="*60)
    print("FID Calculation")
    print("="*60)
    print(f"Pred directory: {pred_dir}")
    print(f"GT directory: {gt_dir}")
    print(f"Device: {device}")
    print(f"Batch size: {batch_size}")
    print("="*60 + "\n")

    pred_dir = Path(pred_dir)
    gt_dir = Path(gt_dir)

    if not pred_dir.exists():
        raise ValueError(f"Pred directory does not exist: {pred_dir}")
    if not gt_dir.exists():
        raise ValueError(f"GT directory does not exist: {gt_dir}")

    print("\n[1/3] Loading predicted videos...")
    pred_frames = load_videos_from_directory(
        pred_dir,
        max_frames_per_video=max_frames_per_video,
        max_videos=max_videos
    )

    print("\n[2/3] Loading ground truth videos...")
    gt_frames = load_videos_from_directory(
        gt_dir,
        max_frames_per_video=max_frames_per_video,
        max_videos=max_videos
    )

    print("\n[3/3] Calculating FID...")
    # The loaders return frames in [0, 255] (uint8 as decoded by OpenCV).
    # normalize=True is meant for float images in [0, 1]: the metric would
    # multiply by 255 and cast to byte, wrapping around on uint8 input and
    # corrupting every frame. normalize=False consumes uint8 [0, 255] as-is.
    fid_model = FrechetInceptionDistance(normalize=False).to(device)

    # Predicted frames are the "fake" distribution, ground truth the "real".
    passes = (
        ("predicted", pred_frames, False),
        ("ground truth", gt_frames, True),
    )
    for label, frames, is_real in passes:
        print(f"Processing {label} frames...")
        for start in tqdm(range(0, frames.shape[0], batch_size)):
            # Move one batch at a time so only batch_size frames sit on device.
            batch = frames[start:start + batch_size].to(device)
            fid_model.update(batch, real=is_real)

    return fid_model.compute().item()
|
|
|
|
def main():
    """
    Command-line entry point: parse arguments, compute FID, report results.

    The score is printed and also written to ``fid_results.txt`` inside
    ``--videos_dir``, together with the pred and gt directories used.

    Returns:
        int: 0 on success, 1 if any step raised an exception
    """
    parser = argparse.ArgumentParser(
        description="Calculate FID between predicted and ground truth videos"
    )
    parser.add_argument(
        "--videos_dir",
        type=str,
        default="/mnt/worldmem_valid/outputs/2025-12-01/08-09-46/videos/test_vis",
        help="Base directory containing 'pred' and 'gt' subdirectories"
    )
    parser.add_argument(
        "--pred_dir",
        type=str,
        default=None,
        help="Override pred directory (default: {videos_dir}/pred)"
    )
    parser.add_argument(
        "--gt_dir",
        type=str,
        default=None,
        help="Override gt directory (default: {videos_dir}/gt)"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size for FID calculation (default: 32)"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use (default: cuda if available)"
    )
    parser.add_argument(
        "--max_frames_per_video",
        type=int,
        default=None,
        help="Maximum frames to load per video (default: None, load all)"
    )
    parser.add_argument(
        "--max_videos",
        type=int,
        default=50,
        # The help text previously claimed "default: None, load all", which
        # contradicted the actual default of 50.
        help="Maximum number of videos to load (default: 50; 0 = load all)"
    )

    args = parser.parse_args()

    videos_dir = Path(args.videos_dir)

    # Explicit overrides win; otherwise fall back to the conventional
    # 'pred' / 'gt' subdirectories of the base directory.
    pred_dir = Path(args.pred_dir) if args.pred_dir else videos_dir / "pred"
    gt_dir = Path(args.gt_dir) if args.gt_dir else videos_dir / "gt"

    try:
        fid_score = calculate_fid(
            pred_dir=pred_dir,
            gt_dir=gt_dir,
            batch_size=args.batch_size,
            device=args.device,
            max_frames_per_video=args.max_frames_per_video,
            max_videos=args.max_videos
        )

        print("\n" + "="*60)
        print("RESULTS")
        print("="*60)
        print(f"FID Score: {fid_score:.4f}")
        print("="*60)

        # Results are always stored next to the base videos directory, even
        # when pred/gt were overridden.
        output_file = videos_dir / "fid_results.txt"
        with open(output_file, 'w', encoding="utf-8") as f:
            f.write(f"FID Score: {fid_score:.4f}\n")
            f.write(f"Pred directory: {pred_dir}\n")
            f.write(f"GT directory: {gt_dir}\n")

        print(f"\nResults saved to: {output_file}")

    except Exception as e:
        # Top-level boundary: report the failure with a traceback and signal
        # it through the exit code instead of letting the exception escape.
        print(f"\n✗ Error: {e}")
        import traceback
        traceback.print_exc()
        return 1

    return 0
|
|
|
|
if __name__ == "__main__":
    # The built-in exit() is injected by the `site` module for interactive
    # use and may be absent (e.g. `python -S`); SystemExit is always
    # available and carries main()'s return value as the exit code.
    raise SystemExit(main())
|
|
|
|