#!/usr/bin/env python3
"""
Check codebook range by iterating through videos and extracting codes.

This script loads videos from the dataset (or pre-extracted video codes),
encodes them to get video codes when needed, and tracks the min/max values
to determine the codebook range actually used.
"""

import argparse
import os
import sys
import logging
from tqdm import tqdm

import torch
import numpy as np

sys.path.append(os.getcwd())

from train.dataset_utils import OpenVid1MDataset, PrecomputedFeatureDataset
from src.pipeline_video import CosmosVideoTokenizer
from transformers import T5Tokenizer
from torch.utils.data import DataLoader

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Tokenizer checkpoint loaded for each --text_encoder_architecture choice.
_TEXT_ENCODER_MODEL_IDS = {
    "umt5-base": "google/umt5-base",
    "umt5-xxl": "google/umt5-xxl",
    "t5": "t5-base",
}


def parse_args():
    """Parse the command-line options for the codebook range check.

    Returns:
        argparse.Namespace with the parsed options.

    Exits through ``parser.error`` when ``--batch_size`` or
    ``--check_interval`` is not a positive integer.
    """
    parser = argparse.ArgumentParser(description="Check codebook range from video dataset")
    parser.add_argument(
        "--csv_path",
        type=str,
        default=None,
        help="Path to OpenVid1M CSV file (if using raw videos)",
    )
    parser.add_argument(
        "--video_root_dir",
        type=str,
        default=None,
        help="Root directory containing video files",
    )
    parser.add_argument(
        "--features_dir",
        type=str,
        default=None,
        help="Directory containing pre-extracted features (if using precomputed features)",
    )
    parser.add_argument(
        "--video_tokenizer_model_id",
        type=str,
        default="Cosmos-1.0-Tokenizer-DV8x16x16",
        help="HuggingFace model ID for Cosmos video tokenizer",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=16,
        help="Number of frames per video",
    )
    parser.add_argument(
        "--video_height",
        type=int,
        default=480,
        help="Video height",
    )
    parser.add_argument(
        "--video_width",
        type=int,
        default=848,
        help="Video width",
    )
    parser.add_argument(
        "--text_encoder_architecture",
        type=str,
        default="umt5-base",
        choices=["umt5-base", "umt5-xxl", "t5"],
        help="Text encoder architecture",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size (use 1 for detailed per-sample tracking)",
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Maximum number of samples to check. If None, check all.",
    )
    parser.add_argument(
        "--check_interval",
        type=int,
        default=10,
        help="Print statistics every N batches (N samples when --batch_size is 1)",
    )
    args = parser.parse_args()
    # A zero interval would make the periodic-print modulo in main() divide by
    # zero. A non-positive batch size is meaningless for the DataLoader. Fail
    # early with a clear message instead.
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")
    if args.check_interval < 1:
        parser.error("--check_interval must be >= 1")
    return args


def _resolve_video_root_dir(csv_path):
    """Return a ``video_reorg`` dir next to or above the CSV, else the CSV's own dir."""
    csv_dir = os.path.dirname(csv_path)
    candidates = (
        os.path.join(csv_dir, "video_reorg"),
        os.path.join(os.path.dirname(csv_dir), "video_reorg"),
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    logger.warning(f"Video directory not found, using CSV directory: {csv_dir}")
    return csv_dir


def _build_dataset(args, use_precomputed):
    """Create the precomputed-feature dataset or the raw OpenVid1M video dataset."""
    if use_precomputed:
        logger.info(f"Using precomputed features from: {args.features_dir}")
        return PrecomputedFeatureDataset(
            features_dir=args.features_dir,
            num_samples=args.max_samples,
        )

    # Auto-detect video_root_dir if not provided
    video_root_dir = args.video_root_dir
    if video_root_dir is None:
        video_root_dir = _resolve_video_root_dir(args.csv_path)

    # Create tokenizer for dataset
    model_id = _TEXT_ENCODER_MODEL_IDS.get(args.text_encoder_architecture)
    if model_id is None:
        raise ValueError(f"Unknown text encoder: {args.text_encoder_architecture}")
    tokenizer = T5Tokenizer.from_pretrained(model_id)

    dataset = OpenVid1MDataset(
        csv_path=args.csv_path,
        video_root_dir=video_root_dir,
        tokenizer=tokenizer,
        num_frames=args.num_frames,
        height=args.video_height,
        width=args.video_width,
        text_encoder_architecture=args.text_encoder_architecture,
        use_random_temporal_crop=False,  # Fixed sampling for consistency
        use_random_crop=False,  # Center crop for consistency
    )
    if args.max_samples is not None:
        # NOTE(review): relies on OpenVid1MDataset exposing a sliceable `.data`
        # whose length drives len(dataset); confirm against dataset_utils.
        dataset.data = dataset.data[:args.max_samples]
        logger.info(f"Limited dataset to {len(dataset)} samples")
    return dataset


def _extract_codes(batch, use_precomputed, video_tokenizer, device):
    """Return the batch's video codes as an int64 tensor."""
    if use_precomputed:
        # Use pre-extracted video codes
        video_codes = batch["video_codes"]  # [B, F', H', W']
        if isinstance(video_codes, torch.Tensor):
            return video_codes.long()
        return torch.from_numpy(video_codes).long()
    # Encode videos to get codes
    videos = batch["video"].to(device, non_blocking=True)  # [B, C, F, H, W]
    video_codes = video_tokenizer.encode(videos)  # [B, F', H', W']
    return video_codes.cpu().long()


def _batch_size_of(batch, default):
    """Leading-dimension size of the batch's code/video entry, or ``default`` if unknown."""
    if isinstance(batch, dict):
        for key in ("video_codes", "video"):
            shape = getattr(batch.get(key), "shape", None)
            if shape:
                return int(shape[0])
    return default


def _print_batch_stats(video_codes, total_samples, batch_range, global_range, video_tokenizer):
    """Print the current batch range, running global range and batch's unique codes."""
    batch_min, batch_max = batch_range
    global_min, global_max = global_range
    print(f"\n[Sample {total_samples}]")
    print(f" Current batch range: [{batch_min}, {batch_max}]")
    print(f" Global range so far: [{global_min}, {global_max}]")
    codebook_size = video_tokenizer.codebook_size if video_tokenizer is not None else "N/A"
    print(f" Codebook size (expected): {codebook_size}")
    if video_tokenizer is not None:
        expected_max = video_tokenizer.codebook_size - 1
        print(f" Expected max (codebook_size - 1): {expected_max}")
        if global_max > expected_max:
            print(f" ⚠️ WARNING: Found code {global_max} > expected max {expected_max}!")
    # Negative codes are invalid whether or not a tokenizer is loaded.
    if global_min < 0:
        print(f" ⚠️ WARNING: Found code {global_min} < 0!")

    # Print unique values count for current batch (torch.unique returns them sorted)
    unique_values = torch.unique(video_codes).tolist()
    print(f" Unique values in batch: {len(unique_values)}")
    if len(unique_values) <= 20:
        print(f" Values: {sorted(unique_values)}")
    else:
        print(f" Min unique: {min(unique_values)}, Max unique: {max(unique_values)}")
    print("-" * 80)


def _log_final_summary(total_samples, failed_samples, global_min, global_max, video_tokenizer):
    """Log totals and compare the observed code range with the tokenizer's codebook."""
    logger.info("=" * 80)
    logger.info("FINAL STATISTICS:")
    logger.info(f" Total samples processed: {total_samples}")
    logger.info(f" Failed samples: {failed_samples}")
    logger.info(f" Global min code: {global_min}")
    logger.info(f" Global max code: {global_max}")
    logger.info(f" Code range: [{global_min}, {global_max}]")

    expected_max = None
    if video_tokenizer is not None:
        expected_max = video_tokenizer.codebook_size - 1
        logger.info(f" Expected max (codebook_size - 1): {expected_max}")
        logger.info(f" Codebook size: {video_tokenizer.codebook_size}")
        logger.info(f" Mask token ID: {video_tokenizer.mask_token_id}")

    if global_min is None or global_max is None:
        # No batch succeeded, so there is no observed range to compare; comparing
        # None with ints would raise TypeError.
        logger.warning(" No video codes were processed; cannot check the code range.")
        logger.info("=" * 80)
        return

    if expected_max is not None:
        if global_max > expected_max:
            logger.warning(f" ⚠️ WARNING: Found code {global_max} > expected max {expected_max}!")
        elif global_max == expected_max:
            logger.info(" ✓ Max code matches expected max")
        else:
            logger.info(f" Note: Max code {global_max} < expected max {expected_max} (some codes may not be used)")

    if global_min < 0:
        logger.warning(f" ⚠️ WARNING: Found code {global_min} < 0!")
    elif global_min == 0:
        logger.info(" ✓ Min code is 0 (as expected)")
    else:
        logger.info(f" Note: Min code {global_min} > 0 (some codes may not be used)")
    logger.info("=" * 80)


def main():
    """Iterate over the dataset, extract video codes and report their min/max range.

    Raises:
        ValueError: if neither ``--csv_path`` nor ``--features_dir`` is given, or
            the text encoder architecture is unknown.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32
    logger.info(f"Using device: {device}")

    # Initialize video tokenizer (only needed if not using precomputed features)
    video_tokenizer = None
    use_precomputed = args.features_dir is not None
    if not use_precomputed:
        if args.csv_path is None:
            raise ValueError("Either --csv_path or --features_dir must be provided")
        logger.info(f"Loading video tokenizer: {args.video_tokenizer_model_id}")
        video_tokenizer = CosmosVideoTokenizer(
            model_id=args.video_tokenizer_model_id, device=device, dtype=dtype
        )
        video_tokenizer.requires_grad_(False)
        video_tokenizer.eval()
        # Get tokenizer info
        logger.info(f"Video tokenizer codebook_size: {video_tokenizer.codebook_size}")
        logger.info(f"Video tokenizer mask_token_id: {video_tokenizer.mask_token_id}")

    dataset = _build_dataset(args, use_precomputed)
    logger.info(f"Dataset size: {len(dataset)}")

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,  # Use 0 to avoid multiprocessing issues
        pin_memory=False,
    )

    global_min = None
    global_max = None
    total_samples = 0
    failed_samples = 0

    logger.info("Starting to check codebook range...")
    logger.info("=" * 80)

    with torch.no_grad():
        # NOTE(review): batches are loaded by the `for` header, outside the try
        # below, so an exception raised while the DataLoader loads a sample
        # aborts the run rather than being counted as a failed batch.
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Checking codes")):
            # Only code extraction is guarded: a failure here is a failed batch.
            try:
                video_codes = _extract_codes(batch, use_precomputed, video_tokenizer, device)
                batch_min = video_codes.min().item()
                batch_max = video_codes.max().item()
                num_in_batch = video_codes.shape[0]
            except Exception as e:
                # The last batch may be smaller than --batch_size; count what was there.
                failed_samples += _batch_size_of(batch, args.batch_size)
                logger.error(f"Failed to process batch {batch_idx}: {e}")
                continue

            global_min = batch_min if global_min is None else min(global_min, batch_min)
            global_max = batch_max if global_max is None else max(global_max, batch_max)
            total_samples += num_in_batch

            # Print statistics periodically (always for the first batch)
            if (batch_idx + 1) % args.check_interval == 0 or batch_idx == 0:
                _print_batch_stats(
                    video_codes,
                    total_samples,
                    (batch_min, batch_max),
                    (global_min, global_max),
                    video_tokenizer,
                )

    _log_final_summary(total_samples, failed_samples, global_min, global_max, video_tokenizer)


if __name__ == "__main__":
    main()