File size: 10,850 Bytes
3d1c0e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 |
#!/usr/bin/env python3
"""
Check codebook range by iterating through videos and extracting codes.
This script loads videos from the dataset, encodes them to get video codes,
and tracks the min/max values to determine the codebook range.
"""
import argparse
import os
import sys
import logging
from tqdm import tqdm
import torch
import numpy as np
sys.path.append(os.getcwd())
from train.dataset_utils import OpenVid1MDataset, PrecomputedFeatureDataset
from src.pipeline_video import CosmosVideoTokenizer
from transformers import T5Tokenizer
from torch.utils.data import DataLoader
# Root logging setup for the script: timestamped records at INFO and above.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
_LOG_DATE_FORMAT = "%m/%d/%Y %H:%M:%S"

logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
def _positive_int(value):
    """argparse ``type`` callable that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}"
        ) from None
    if number < 1:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}"
        )
    return number


def parse_args(argv=None):
    """Parse the command-line options of the codebook range checker.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by argparse on invalid input, including any
            size/count option that is not a positive integer.
    """
    parser = argparse.ArgumentParser(description="Check codebook range from video dataset")

    # --- data source -------------------------------------------------------
    parser.add_argument(
        "--csv_path",
        type=str,
        default=None,
        help="Path to OpenVid1M CSV file (if using raw videos)",
    )
    parser.add_argument(
        "--video_root_dir",
        type=str,
        default=None,
        help="Root directory containing video files",
    )
    parser.add_argument(
        "--features_dir",
        type=str,
        default=None,
        help="Directory containing pre-extracted features (if using precomputed features)",
    )

    # --- models ------------------------------------------------------------
    parser.add_argument(
        "--video_tokenizer_model_id",
        type=str,
        default="Cosmos-1.0-Tokenizer-DV8x16x16",
        help="HuggingFace model ID for Cosmos video tokenizer",
    )
    parser.add_argument(
        "--text_encoder_architecture",
        type=str,
        default="umt5-base",
        choices=["umt5-base", "umt5-xxl", "t5"],
        help="Text encoder architecture",
    )

    # --- video geometry ----------------------------------------------------
    parser.add_argument(
        "--num_frames",
        type=_positive_int,
        default=16,
        help="Number of frames per video",
    )
    parser.add_argument(
        "--video_height",
        type=_positive_int,
        default=480,
        help="Video height",
    )
    parser.add_argument(
        "--video_width",
        type=_positive_int,
        default=848,
        help="Video width",
    )

    # --- iteration control -------------------------------------------------
    parser.add_argument(
        "--batch_size",
        type=_positive_int,
        default=1,
        help="Batch size (use 1 for detailed per-sample tracking)",
    )
    parser.add_argument(
        "--max_samples",
        type=_positive_int,
        default=None,
        help="Maximum number of samples to check. If None, check all.",
    )
    # Must be >= 1: the interval is used as a modulo divisor when deciding
    # whether to print statistics for a batch.
    parser.add_argument(
        "--check_interval",
        type=_positive_int,
        default=10,
        help="Print statistics every N samples",
    )

    return parser.parse_args(argv)
def _resolve_video_root_dir(csv_path, video_root_dir=None):
    """Return the directory that holds the raw video files.

    An explicit ``video_root_dir`` wins. Otherwise a ``video_reorg`` directory
    is looked for next to the CSV file, then one level above it; if neither
    exists the CSV's own directory is used and a warning is logged.
    """
    if video_root_dir is not None:
        return video_root_dir

    # abspath so that a bare file name ("data.csv") still yields a real
    # directory and a real parent to probe, instead of the empty string.
    csv_dir = os.path.dirname(os.path.abspath(csv_path))
    for base_dir in (csv_dir, os.path.dirname(csv_dir)):
        candidate = os.path.join(base_dir, "video_reorg")
        if os.path.exists(candidate):
            return candidate

    logger.warning(f"Video directory not found, using CSV directory: {csv_dir}")
    return csv_dir


def _build_dataset(args):
    """Create the dataset selected by the CLI options (precomputed or raw videos)."""
    if args.features_dir is not None:
        logger.info(f"Using precomputed features from: {args.features_dir}")
        return PrecomputedFeatureDataset(
            features_dir=args.features_dir,
            num_samples=args.max_samples,
        )

    video_root_dir = _resolve_video_root_dir(args.csv_path, args.video_root_dir)

    # Text tokenizer required by the raw-video dataset constructor.
    text_model_ids = {
        "umt5-base": "google/umt5-base",
        "umt5-xxl": "google/umt5-xxl",
        "t5": "t5-base",
    }
    if args.text_encoder_architecture not in text_model_ids:
        raise ValueError(f"Unknown text encoder: {args.text_encoder_architecture}")
    tokenizer = T5Tokenizer.from_pretrained(text_model_ids[args.text_encoder_architecture])

    dataset = OpenVid1MDataset(
        csv_path=args.csv_path,
        video_root_dir=video_root_dir,
        tokenizer=tokenizer,
        num_frames=args.num_frames,
        height=args.video_height,
        width=args.video_width,
        text_encoder_architecture=args.text_encoder_architecture,
        use_random_temporal_crop=False,  # Fixed sampling for consistency
        use_random_crop=False,  # Center crop for consistency
    )
    if args.max_samples is not None:
        dataset.data = dataset.data[:args.max_samples]
        logger.info(f"Limited dataset to {len(dataset)} samples")
    return dataset


def _extract_video_codes(batch, video_tokenizer, device):
    """Return the integer code tensor of one batch, as ``torch.long``.

    With ``video_tokenizer`` set to ``None`` the codes are read from
    ``batch["video_codes"]``; otherwise ``batch["video"]`` is moved to
    ``device`` and encoded, and the result is brought back to the CPU.
    """
    if video_tokenizer is None:
        # Use pre-extracted video codes ([B, F', H', W'] per the original note).
        video_codes = batch["video_codes"]
        if not isinstance(video_codes, torch.Tensor):
            video_codes = torch.from_numpy(video_codes)
        return video_codes.long()

    # Encode videos to get codes ([B, C, F, H, W] -> [B, F', H', W'] per the
    # original note).
    videos = batch["video"].to(device, non_blocking=True)
    return video_tokenizer.encode(videos).cpu().long()


def _print_batch_report(video_codes, total_samples, batch_range, global_range, codebook_size):
    """Print the periodic per-batch statistics block to stdout."""
    batch_min, batch_max = batch_range
    global_min, global_max = global_range

    print(f"\n[Sample {total_samples}]")
    print(f" Current batch range: [{batch_min}, {batch_max}]")
    print(f" Global range so far: [{global_min}, {global_max}]")
    print(f" Codebook size (expected): {codebook_size if codebook_size is not None else 'N/A'}")
    if codebook_size is not None:
        expected_max = codebook_size - 1
        print(f" Expected max (codebook_size - 1): {expected_max}")
        if global_max > expected_max:
            print(f" ⚠️ WARNING: Found code {global_max} > expected max {expected_max}!")
    if global_min < 0:
        print(f" ⚠️ WARNING: Found code {global_min} < 0!")

    # Print unique values count for current batch
    unique_values = torch.unique(video_codes).tolist()
    print(f" Unique values in batch: {len(unique_values)}")
    if len(unique_values) <= 20:
        print(f" Values: {sorted(unique_values)}")
    else:
        print(f" Min unique: {min(unique_values)}, Max unique: {max(unique_values)}")
    print("-" * 80)


def _log_final_summary(total_samples, failed_samples, global_min, global_max, video_tokenizer):
    """Log the final statistics and compare the range with the tokenizer's codebook."""
    logger.info("=" * 80)
    logger.info("FINAL STATISTICS:")
    logger.info(f" Total samples processed: {total_samples}")
    logger.info(f" Failed samples: {failed_samples}")
    logger.info(f" Global min code: {global_min}")
    logger.info(f" Global max code: {global_max}")
    logger.info(f" Code range: [{global_min}, {global_max}]")

    if global_min is None or global_max is None:
        # Empty dataset or every batch failed: there is no range to compare,
        # and comparing None with an int would raise TypeError.
        logger.warning(" No batch was processed successfully; code range is undefined.")
        logger.info("=" * 80)
        return

    if video_tokenizer is not None:
        expected_max = video_tokenizer.codebook_size - 1
        logger.info(f" Expected max (codebook_size - 1): {expected_max}")
        logger.info(f" Codebook size: {video_tokenizer.codebook_size}")
        logger.info(f" Mask token ID: {video_tokenizer.mask_token_id}")
        if global_max > expected_max:
            logger.warning(f" ⚠️ WARNING: Found code {global_max} > expected max {expected_max}!")
        elif global_max == expected_max:
            logger.info(" ✓ Max code matches expected max")
        else:
            logger.info(f" Note: Max code {global_max} < expected max {expected_max} (some codes may not be used)")

    if global_min < 0:
        logger.warning(f" ⚠️ WARNING: Found code {global_min} < 0!")
    elif global_min == 0:
        logger.info(" ✓ Min code is 0 (as expected)")
    else:
        logger.info(f" Note: Min code {global_min} > 0 (some codes may not be used)")
    logger.info("=" * 80)


def main():
    """Scan a video dataset and report the min/max video code encountered.

    Codes come either from precomputed features (``--features_dir``) or from
    encoding raw videos with the Cosmos video tokenizer (``--csv_path``).
    Statistics are printed for the first batch and then every
    ``--check_interval`` batches; a final summary is logged at the end.

    Raises:
        ValueError: If neither ``--csv_path`` nor ``--features_dir`` is given,
            if ``--check_interval`` is not positive, or if the text encoder
            architecture is unknown.
    """
    args = parse_args()
    if args.check_interval < 1:
        # The interval is a modulo divisor below; 0 would raise on every batch.
        raise ValueError("--check_interval must be a positive integer")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32
    logger.info(f"Using device: {device}")

    # Initialize video tokenizer (only needed if not using precomputed features)
    video_tokenizer = None
    use_precomputed = args.features_dir is not None
    if not use_precomputed:
        if args.csv_path is None:
            raise ValueError("Either --csv_path or --features_dir must be provided")
        logger.info(f"Loading video tokenizer: {args.video_tokenizer_model_id}")
        video_tokenizer = CosmosVideoTokenizer(
            model_id=args.video_tokenizer_model_id,
            device=device,
            dtype=dtype,
        )
        video_tokenizer.requires_grad_(False)
        video_tokenizer.eval()
        # Get tokenizer info
        logger.info(f"Video tokenizer codebook_size: {video_tokenizer.codebook_size}")
        logger.info(f"Video tokenizer mask_token_id: {video_tokenizer.mask_token_id}")

    dataset = _build_dataset(args)
    logger.info(f"Dataset size: {len(dataset)}")

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,  # Use 0 to avoid multiprocessing issues
        pin_memory=False,
    )

    # Read once; None means "no tokenizer loaded" (precomputed mode).
    codebook_size = video_tokenizer.codebook_size if video_tokenizer is not None else None

    # Initialize statistics
    global_min = None
    global_max = None
    total_samples = 0
    failed_samples = 0

    logger.info("Starting to check codebook range...")
    logger.info("=" * 80)

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Checking codes")):
            # Only the data-dependent work is best-effort; reporting errors
            # must not be miscounted as failed batches.
            try:
                video_codes = _extract_video_codes(batch, video_tokenizer, device)
                batch_min = video_codes.min().item()
                batch_max = video_codes.max().item()
            except Exception as exc:
                # The real size of a failed batch is unknown here, so the
                # nominal batch size is used (may overcount a short last batch).
                failed_samples += args.batch_size
                logger.error(f"Failed to process batch {batch_idx}: {exc}")
                continue

            # Update statistics
            global_min = batch_min if global_min is None else min(global_min, batch_min)
            global_max = batch_max if global_max is None else max(global_max, batch_max)
            total_samples += video_codes.shape[0]

            # Print statistics periodically
            if batch_idx == 0 or (batch_idx + 1) % args.check_interval == 0:
                _print_batch_report(
                    video_codes,
                    total_samples,
                    (batch_min, batch_max),
                    (global_min, global_max),
                    codebook_size,
                )

    _log_final_summary(total_samples, failed_samples, global_min, global_max, video_tokenizer)
# Run the range check only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|