|
|
|
|
|
""" |
|
|
Check codebook range by iterating through videos and extracting codes. |
|
|
|
|
|
This script loads videos from the dataset, encodes them to get video codes, |
|
|
and tracks the min/max values to determine the codebook range. |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import sys |
|
|
import logging |
|
|
from tqdm import tqdm |
|
|
import torch |
|
|
import numpy as np |
|
|
|
|
|
# Make repo-root-relative imports (train.*, src.*) resolvable when run as a script.
sys.path.append(os.getcwd())
|
|
|
|
|
from train.dataset_utils import OpenVid1MDataset, PrecomputedFeatureDataset |
|
|
from src.pipeline_video import CosmosVideoTokenizer |
|
|
from transformers import T5Tokenizer |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
# Configure root logging once for the whole script: timestamped INFO-level
# records so progress and warnings are visible when run from the command line.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def parse_args():
    """Build and parse the command-line arguments for the codebook-range check.

    Returns:
        argparse.Namespace with the options below. Either ``--csv_path`` (raw
        videos) or ``--features_dir`` (precomputed features) is expected;
        that constraint is enforced by the caller, not here.
    """
    ap = argparse.ArgumentParser(description="Check codebook range from video dataset")

    # Data sources: raw videos via CSV, or pre-extracted features on disk.
    ap.add_argument("--csv_path", type=str, default=None,
                    help="Path to OpenVid1M CSV file (if using raw videos)")
    ap.add_argument("--video_root_dir", type=str, default=None,
                    help="Root directory containing video files")
    ap.add_argument("--features_dir", type=str, default=None,
                    help="Directory containing pre-extracted features (if using precomputed features)")

    # Tokenizer / video geometry.
    ap.add_argument("--video_tokenizer_model_id", type=str,
                    default="Cosmos-1.0-Tokenizer-DV8x16x16",
                    help="HuggingFace model ID for Cosmos video tokenizer")
    ap.add_argument("--num_frames", type=int, default=16,
                    help="Number of frames per video")
    ap.add_argument("--video_height", type=int, default=480,
                    help="Video height")
    ap.add_argument("--video_width", type=int, default=848,
                    help="Video width")

    # Text encoder used by the raw-video dataset pipeline.
    ap.add_argument("--text_encoder_architecture", type=str, default="umt5-base",
                    choices=["umt5-base", "umt5-xxl", "t5"],
                    help="Text encoder architecture")

    # Iteration controls.
    ap.add_argument("--batch_size", type=int, default=1,
                    help="Batch size (use 1 for detailed per-sample tracking)")
    ap.add_argument("--max_samples", type=int, default=None,
                    help="Maximum number of samples to check. If None, check all.")
    ap.add_argument("--check_interval", type=int, default=10,
                    help="Print statistics every N samples")

    return ap.parse_args()
|
|
|
|
|
|
|
|
def main():
    """Scan the dataset, collect video codes, and report their value range.

    Codes either come from precomputed features on disk (``--features_dir``)
    or are produced on the fly by encoding raw videos with the Cosmos video
    tokenizer. Tracks the global min/max code seen and compares it against
    the tokenizer's expected ``[0, codebook_size - 1]`` range.
    """
    args = parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    logger.info(f"Using device: {device}")

    # The tokenizer is only needed when encoding raw videos; with precomputed
    # features the codes are loaded directly from disk and this stays None.
    video_tokenizer = None
    use_precomputed = args.features_dir is not None

    if not use_precomputed:
        if args.csv_path is None:
            raise ValueError("Either --csv_path or --features_dir must be provided")

        logger.info(f"Loading video tokenizer: {args.video_tokenizer_model_id}")
        video_tokenizer = CosmosVideoTokenizer(
            model_id=args.video_tokenizer_model_id,
            device=device,
            dtype=dtype
        )
        # Inference only: freeze parameters and switch to eval mode.
        video_tokenizer.requires_grad_(False)
        video_tokenizer.eval()

        # NOTE(review): kept inside this branch — video_tokenizer is None in
        # the precomputed path, and later accesses are guarded with
        # `if video_tokenizer:`. Confirm against the original indentation.
        logger.info(f"Video tokenizer codebook_size: {video_tokenizer.codebook_size}")
        logger.info(f"Video tokenizer mask_token_id: {video_tokenizer.mask_token_id}")

    if use_precomputed:
        logger.info(f"Using precomputed features from: {args.features_dir}")
        # The dataset itself honors max_samples via num_samples.
        dataset = PrecomputedFeatureDataset(
            features_dir=args.features_dir,
            num_samples=args.max_samples,
        )
    else:
        # Resolve the video root: explicit argument wins; otherwise look for a
        # 'video_reorg' folder next to the CSV, then one level above it, and
        # finally fall back to the CSV's own directory with a warning.
        if args.video_root_dir is None:
            csv_dir = os.path.dirname(args.csv_path)
            if os.path.exists(os.path.join(csv_dir, 'video_reorg')):
                video_root_dir = os.path.join(csv_dir, 'video_reorg')
            elif os.path.exists(os.path.join(os.path.dirname(csv_dir), 'video_reorg')):
                video_root_dir = os.path.join(os.path.dirname(csv_dir), 'video_reorg')
            else:
                video_root_dir = csv_dir
                logger.warning(f"Video directory not found, using CSV directory: {video_root_dir}")
        else:
            video_root_dir = args.video_root_dir

        # Map the architecture flag to its HuggingFace checkpoint id.
        if args.text_encoder_architecture == "umt5-base":
            model_id = "google/umt5-base"
        elif args.text_encoder_architecture == "umt5-xxl":
            model_id = "google/umt5-xxl"
        elif args.text_encoder_architecture == "t5":
            model_id = "t5-base"
        else:
            raise ValueError(f"Unknown text encoder: {args.text_encoder_architecture}")

        tokenizer = T5Tokenizer.from_pretrained(model_id)

        # Deterministic (non-random) crops so the scan is reproducible.
        dataset = OpenVid1MDataset(
            csv_path=args.csv_path,
            video_root_dir=video_root_dir,
            tokenizer=tokenizer,
            num_frames=args.num_frames,
            height=args.video_height,
            width=args.video_width,
            text_encoder_architecture=args.text_encoder_architecture,
            use_random_temporal_crop=False,
            use_random_crop=False,
        )

        if args.max_samples is not None:
            # Truncate the underlying sample list in place to cap the scan.
            dataset.data = dataset.data[:args.max_samples]
            logger.info(f"Limited dataset to {len(dataset)} samples")

    logger.info(f"Dataset size: {len(dataset)}")

    # Sequential, single-process loading keeps per-sample tracking simple.
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=False,
    )

    # Running extrema over all codes seen so far; None until the first batch.
    global_min = None
    global_max = None
    total_samples = 0
    failed_samples = 0

    logger.info("Starting to check codebook range...")
    logger.info("=" * 80)

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Checking codes")):
            try:
                if use_precomputed:
                    # Codes come straight from disk; normalize to a long tensor
                    # (loaders may yield torch tensors or numpy arrays).
                    video_codes = batch["video_codes"]
                    if isinstance(video_codes, torch.Tensor):
                        video_codes = video_codes.long()
                    else:
                        video_codes = torch.from_numpy(video_codes).long()
                else:
                    # Encode raw frames on the device, then bring codes to CPU.
                    videos = batch["video"].to(device, non_blocking=True)
                    video_codes = video_tokenizer.encode(videos)
                    video_codes = video_codes.cpu().long()

                batch_min = video_codes.min().item()
                batch_max = video_codes.max().item()

                if global_min is None:
                    global_min = batch_min
                    global_max = batch_max
                else:
                    global_min = min(global_min, batch_min)
                    global_max = max(global_max, batch_max)

                total_samples += video_codes.shape[0]

                # Periodic progress report (always printed for the first batch).
                if (batch_idx + 1) % args.check_interval == 0 or batch_idx == 0:
                    print(f"\n[Sample {total_samples}]")
                    print(f" Current batch range: [{batch_min}, {batch_max}]")
                    print(f" Global range so far: [{global_min}, {global_max}]")
                    print(f" Codebook size (expected): {video_tokenizer.codebook_size if video_tokenizer else 'N/A'}")
                    if video_tokenizer:
                        expected_max = video_tokenizer.codebook_size - 1
                        print(f" Expected max (codebook_size - 1): {expected_max}")
                        if global_max > expected_max:
                            print(f" ⚠️ WARNING: Found code {global_max} > expected max {expected_max}!")
                        if global_min < 0:
                            print(f" ⚠️ WARNING: Found code {global_min} < 0!")

                    # Quick distribution snapshot for this batch only.
                    unique_values = torch.unique(video_codes).tolist()
                    print(f" Unique values in batch: {len(unique_values)}")
                    if len(unique_values) <= 20:
                        print(f" Values: {sorted(unique_values)}")
                    else:
                        print(f" Min unique: {min(unique_values)}, Max unique: {max(unique_values)}")
                    print("-" * 80)

            except Exception as e:
                # Best-effort scan: skip unreadable/corrupt samples but count them.
                # NOTE(review): counts a full batch_size per failure; the last
                # (partial) batch may contain fewer samples — confirm acceptable.
                failed_samples += args.batch_size
                logger.error(f"Failed to process batch {batch_idx}: {e}")
                continue

    logger.info("=" * 80)
    logger.info("FINAL STATISTICS:")
    logger.info(f" Total samples processed: {total_samples}")
    logger.info(f" Failed samples: {failed_samples}")
    logger.info(f" Global min code: {global_min}")
    logger.info(f" Global max code: {global_max}")
    logger.info(f" Code range: [{global_min}, {global_max}]")

    if video_tokenizer:
        # Compare observed range against the tokenizer's declared codebook.
        # NOTE(review): if every batch failed, global_min/global_max are None
        # and the comparisons below would raise a TypeError — confirm intended.
        expected_max = video_tokenizer.codebook_size - 1
        logger.info(f" Expected max (codebook_size - 1): {expected_max}")
        logger.info(f" Codebook size: {video_tokenizer.codebook_size}")
        logger.info(f" Mask token ID: {video_tokenizer.mask_token_id}")

        if global_max > expected_max:
            logger.warning(f" ⚠️ WARNING: Found code {global_max} > expected max {expected_max}!")
        elif global_max == expected_max:
            logger.info(f" ✓ Max code matches expected max")
        else:
            logger.info(f" Note: Max code {global_max} < expected max {expected_max} (some codes may not be used)")

        if global_min < 0:
            logger.warning(f" ⚠️ WARNING: Found code {global_min} < 0!")
        elif global_min == 0:
            logger.info(f" ✓ Min code is 0 (as expected)")
        else:
            logger.info(f" Note: Min code {global_min} > 0 (some codes may not be used)")

    logger.info("=" * 80)
|
|
|
|
|
|
|
|
# Script entry point: run the scan only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|
|
|
|
|