# Meissonic / train / check_codebook_range.py
# (HuggingFace upload-page header preserved as comments: uploaded by BryanW
#  via huggingface_hub, commit 3d1c0e1, verified)
#!/usr/bin/env python3
"""
Check codebook range by iterating through videos and extracting codes.
This script loads videos from the dataset, encodes them to get video codes,
and tracks the min/max values to determine the codebook range.
"""
import argparse
import os
import sys
import logging
from tqdm import tqdm
import torch
import numpy as np
sys.path.append(os.getcwd())
from train.dataset_utils import OpenVid1MDataset, PrecomputedFeatureDataset
from src.pipeline_video import CosmosVideoTokenizer
from transformers import T5Tokenizer
from torch.utils.data import DataLoader
# Configure root logging once at import time: timestamped, INFO-level records
# so the script's progress and warnings are visible on the console.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)  # module-level logger, standard convention
def parse_args():
    """Parse command-line options for the codebook-range check script.

    Returns:
        argparse.Namespace with dataset paths, tokenizer/model selection,
        video geometry, and reporting options.
    """
    parser = argparse.ArgumentParser(description="Check codebook range from video dataset")
    # (flag, kwargs) table keeps the option list compact and easy to scan;
    # every entry is registered with identical semantics below.
    option_table = (
        ("--csv_path", dict(
            type=str, default=None,
            help="Path to OpenVid1M CSV file (if using raw videos)")),
        ("--video_root_dir", dict(
            type=str, default=None,
            help="Root directory containing video files")),
        ("--features_dir", dict(
            type=str, default=None,
            help="Directory containing pre-extracted features (if using precomputed features)")),
        ("--video_tokenizer_model_id", dict(
            type=str, default="Cosmos-1.0-Tokenizer-DV8x16x16",
            help="HuggingFace model ID for Cosmos video tokenizer")),
        ("--num_frames", dict(
            type=int, default=16,
            help="Number of frames per video")),
        ("--video_height", dict(
            type=int, default=480,
            help="Video height")),
        ("--video_width", dict(
            type=int, default=848,
            help="Video width")),
        ("--text_encoder_architecture", dict(
            type=str, default="umt5-base",
            choices=["umt5-base", "umt5-xxl", "t5"],
            help="Text encoder architecture")),
        ("--batch_size", dict(
            type=int, default=1,
            help="Batch size (use 1 for detailed per-sample tracking)")),
        ("--max_samples", dict(
            type=int, default=None,
            help="Maximum number of samples to check. If None, check all.")),
        ("--check_interval", dict(
            type=int, default=10,
            help="Print statistics every N samples")),
    )
    for flag, kwargs in option_table:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
def _resolve_video_root(csv_path, video_root_dir):
    """Return the directory that holds the raw video files.

    If *video_root_dir* was given, it is used verbatim.  Otherwise a
    'video_reorg' folder is searched for next to the CSV file, then one
    directory level up; as a last resort the CSV's own directory is used
    (with a warning).
    """
    if video_root_dir is not None:
        return video_root_dir
    csv_dir = os.path.dirname(csv_path)
    for base in (csv_dir, os.path.dirname(csv_dir)):
        candidate = os.path.join(base, 'video_reorg')
        if os.path.exists(candidate):
            return candidate
    logger.warning(f"Video directory not found, using CSV directory: {csv_dir}")
    return csv_dir


def _text_encoder_model_id(architecture):
    """Map a --text_encoder_architecture choice to its HuggingFace model id.

    Raises:
        ValueError: if *architecture* is not one of the supported choices.
    """
    model_ids = {
        "umt5-base": "google/umt5-base",
        "umt5-xxl": "google/umt5-xxl",
        "t5": "t5-base",
    }
    try:
        return model_ids[architecture]
    except KeyError:
        raise ValueError(f"Unknown text encoder: {architecture}") from None


def _build_dataset(args, use_precomputed):
    """Construct the dataset: precomputed features or raw OpenVid1M videos."""
    if use_precomputed:
        logger.info(f"Using precomputed features from: {args.features_dir}")
        return PrecomputedFeatureDataset(
            features_dir=args.features_dir,
            num_samples=args.max_samples,
        )
    video_root_dir = _resolve_video_root(args.csv_path, args.video_root_dir)
    tokenizer = T5Tokenizer.from_pretrained(
        _text_encoder_model_id(args.text_encoder_architecture)
    )
    dataset = OpenVid1MDataset(
        csv_path=args.csv_path,
        video_root_dir=video_root_dir,
        tokenizer=tokenizer,
        num_frames=args.num_frames,
        height=args.video_height,
        width=args.video_width,
        text_encoder_architecture=args.text_encoder_architecture,
        use_random_temporal_crop=False,  # fixed sampling for consistency
        use_random_crop=False,  # center crop for consistency
    )
    if args.max_samples is not None:
        # NOTE(review): assumes OpenVid1MDataset keeps its rows on `.data`
        # (as the original script did) — confirm against the dataset class.
        dataset.data = dataset.data[:args.max_samples]
        logger.info(f"Limited dataset to {len(dataset)} samples")
    return dataset


def _extract_codes(batch, use_precomputed, video_tokenizer, device):
    """Return the integer video-code tensor for one batch, on CPU.

    Shape is presumably [B, F', H', W'] per the original annotations —
    downstream only relies on dim 0 being the batch dimension.
    """
    if use_precomputed:
        video_codes = batch["video_codes"]
        if not isinstance(video_codes, torch.Tensor):
            video_codes = torch.from_numpy(video_codes)
        return video_codes.long()
    videos = batch["video"].to(device, non_blocking=True)  # [B, C, F, H, W]
    return video_tokenizer.encode(videos).cpu().long()


def main():
    """Iterate the dataset, collect video token ids, and report their observed
    range against the tokenizer's expected [0, codebook_size - 1] interval."""
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32
    logger.info(f"Using device: {device}")

    # The video tokenizer is only needed when encoding raw videos.
    video_tokenizer = None
    use_precomputed = args.features_dir is not None
    if not use_precomputed:
        if args.csv_path is None:
            raise ValueError("Either --csv_path or --features_dir must be provided")
        logger.info(f"Loading video tokenizer: {args.video_tokenizer_model_id}")
        video_tokenizer = CosmosVideoTokenizer(
            model_id=args.video_tokenizer_model_id,
            device=device,
            dtype=dtype,
        )
        video_tokenizer.requires_grad_(False)  # inference only
        video_tokenizer.eval()
        logger.info(f"Video tokenizer codebook_size: {video_tokenizer.codebook_size}")
        logger.info(f"Video tokenizer mask_token_id: {video_tokenizer.mask_token_id}")

    dataset = _build_dataset(args, use_precomputed)
    logger.info(f"Dataset size: {len(dataset)}")

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,  # use 0 to avoid multiprocessing issues
        pin_memory=False,
    )

    # Running statistics over all successfully processed batches.
    global_min = None
    global_max = None
    total_samples = 0
    failed_samples = 0
    logger.info("Starting to check codebook range...")
    logger.info("=" * 80)
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Checking codes")):
            try:
                video_codes = _extract_codes(batch, use_precomputed, video_tokenizer, device)
                batch_min = video_codes.min().item()
                batch_max = video_codes.max().item()
                if global_min is None:
                    global_min, global_max = batch_min, batch_max
                else:
                    global_min = min(global_min, batch_min)
                    global_max = max(global_max, batch_max)
                total_samples += video_codes.shape[0]
                # Print statistics periodically (and for the very first batch).
                if (batch_idx + 1) % args.check_interval == 0 or batch_idx == 0:
                    print(f"\n[Sample {total_samples}]")
                    print(f" Current batch range: [{batch_min}, {batch_max}]")
                    print(f" Global range so far: [{global_min}, {global_max}]")
                    print(f" Codebook size (expected): {video_tokenizer.codebook_size if video_tokenizer else 'N/A'}")
                    if video_tokenizer:
                        expected_max = video_tokenizer.codebook_size - 1
                        print(f" Expected max (codebook_size - 1): {expected_max}")
                        if global_max > expected_max:
                            print(f" ⚠️ WARNING: Found code {global_max} > expected max {expected_max}!")
                    # Min check does not depend on the tokenizer being loaded.
                    if global_min < 0:
                        print(f" ⚠️ WARNING: Found code {global_min} < 0!")
                    unique_values = torch.unique(video_codes).tolist()
                    print(f" Unique values in batch: {len(unique_values)}")
                    if len(unique_values) <= 20:
                        print(f" Values: {sorted(unique_values)}")
                    else:
                        print(f" Min unique: {min(unique_values)}, Max unique: {max(unique_values)}")
                    print("-" * 80)
            except Exception as e:
                # BUGFIX: count the actual number of samples in the failed
                # batch; the last batch may be smaller than args.batch_size.
                try:
                    key = "video_codes" if use_precomputed else "video"
                    failed_samples += len(batch[key])
                except Exception:
                    failed_samples += args.batch_size
                logger.error(f"Failed to process batch {batch_idx}: {e}")
                continue

    # Final summary.
    logger.info("=" * 80)
    logger.info("FINAL STATISTICS:")
    logger.info(f" Total samples processed: {total_samples}")
    logger.info(f" Failed samples: {failed_samples}")
    if global_min is None:
        # BUGFIX: the original compared None against ints here (TypeError)
        # when every batch failed; bail out with a warning instead.
        logger.warning(" No samples were processed successfully; no code range to report.")
        logger.info("=" * 80)
        return
    logger.info(f" Global min code: {global_min}")
    logger.info(f" Global max code: {global_max}")
    logger.info(f" Code range: [{global_min}, {global_max}]")
    if video_tokenizer:
        expected_max = video_tokenizer.codebook_size - 1
        logger.info(f" Expected max (codebook_size - 1): {expected_max}")
        logger.info(f" Codebook size: {video_tokenizer.codebook_size}")
        logger.info(f" Mask token ID: {video_tokenizer.mask_token_id}")
        if global_max > expected_max:
            logger.warning(f" ⚠️ WARNING: Found code {global_max} > expected max {expected_max}!")
        elif global_max == expected_max:
            logger.info(f" ✓ Max code matches expected max")
        else:
            logger.info(f" Note: Max code {global_max} < expected max {expected_max} (some codes may not be used)")
    # Min-code report does not depend on the tokenizer being loaded.
    if global_min < 0:
        logger.warning(f" ⚠️ WARNING: Found code {global_min} < 0!")
    elif global_min == 0:
        logger.info(f" ✓ Min code is 0 (as expected)")
    else:
        logger.info(f" Note: Min code {global_min} > 0 (some codes may not be used)")
    logger.info("=" * 80)
# Script entry point: run the scan only when executed directly, not on import.
if __name__ == "__main__":
    main()