|
|
|
|
|
""" |
|
|
Extract video codes and text embeddings from video-text pairs for efficient training. |
|
|
|
|
|
This script pre-extracts: |
|
|
1. Video codes: Discrete tokens from CosmosVideoTokenizer |
|
|
2. Text embeddings: Encoder hidden states from T5/UMT5 |
|
|
|
|
|
The extracted features are saved to disk and can be loaded directly during training, |
|
|
avoiding repeated encoding operations. |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import sys |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from tqdm import tqdm |
|
|
import torch |
|
|
import numpy as np |
|
|
from torch.utils.data import DataLoader, DistributedSampler |
|
|
import json |
|
|
|
|
|
sys.path.append(os.getcwd()) |
|
|
|
|
|
from train.dataset_utils import OpenVid1MDataset, tokenize_prompt, encode_prompt |
|
|
from src.pipeline_video import CosmosVideoTokenizer |
|
|
from transformers import T5Tokenizer, T5EncoderModel |
|
|
from accelerate import Accelerator |
|
|
|
|
|
# Module-wide logger used for all progress and error reporting below.
logger = logging.getLogger(__name__)

# Timestamped, INFO-level output on the root logger. Per the logging docs,
# basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
|
|
|
|
|
|
|
|
def get_hierarchical_path(base_dir, index):
    """
    Map a sample index to its leaf directory and ``.npy`` file path.

    Layout: ``base_dir/<level1>/<level2>/<level3>/<index:08d>.npy`` where

    - level1 = index // 1000000         (zero-padded to 3 digits; exceeds 999 once index >= 10**9)
    - level2 = (index // 1000) % 1000   (0-999)
    - level3 = index % 1000             (0-999)

    The three levels together determine ``index`` uniquely, so every leaf
    directory holds exactly one file. Anything that reads these features back
    must compute the same paths, so keep this layout stable.

    Args:
        base_dir: Root directory for one feature type (e.g. video codes).
        index: Non-negative global sample index.

    Returns:
        Tuple ``(dir_path, file_path)``: the leaf directory (callers are
        expected to create it before saving) and the full ``.npy`` path inside it.

    Raises:
        ValueError: if ``index`` is negative. Negative values would otherwise
            yield paths such as ``-01/999/999/-0000001.npy``.
    """
    if index < 0:
        raise ValueError(f"index must be non-negative, got {index}")

    level1 = index // 1000000
    level2 = (index // 1000) % 1000
    level3 = index % 1000

    dir_path = os.path.join(
        base_dir,
        f"{level1:03d}",
        f"{level2:03d}",
        f"{level3:03d}",
    )
    file_path = os.path.join(dir_path, f"{index:08d}.npy")

    return dir_path, file_path
|
|
|
|
|
|
|
|
def parse_args():
    """Define the extraction CLI and parse ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    parser = argparse.ArgumentParser(description="Extract video codes and text embeddings")

    # (flag, add_argument kwargs); registration order is preserved so --help
    # output lists the options in this sequence.
    option_specs = [
        ("--csv_path", dict(
            type=str, required=True,
            help="Path to OpenVid1M CSV file")),
        ("--video_root_dir", dict(
            type=str, default=None,
            help="Root directory containing video files. If None, will auto-detect.")),
        ("--output_dir", dict(
            type=str, required=True,
            help="Output directory to save extracted features")),
        ("--text_encoder_architecture", dict(
            type=str, default="umt5-base", choices=["umt5-base", "umt5-xxl", "t5"],
            help="Text encoder architecture")),
        ("--video_tokenizer_model_id", dict(
            type=str, default="Cosmos-1.0-Tokenizer-DV8x16x16",
            help="HuggingFace model ID for Cosmos video tokenizer")),
        ("--num_frames", dict(
            type=int, default=16,
            help="Number of frames per video")),
        ("--video_height", dict(
            type=int, default=480,
            help="Video height")),
        ("--video_width", dict(
            type=int, default=848,
            help="Video width")),
        ("--batch_size", dict(
            type=int, default=4,
            help="Batch size for feature extraction")),
        ("--num_workers", dict(
            type=int, default=4,
            help="Number of dataloader workers")),
        ("--max_samples", dict(
            type=int, default=None,
            help="Maximum number of samples to process (for testing). If None, process all.")),
        ("--resume_from_index", dict(
            type=int, default=0,
            help="Resume extraction from this index (for resuming interrupted extraction)")),
        ("--prompt_prefix", dict(
            type=str, default=None,
            help="Prefix to add to prompts")),
        ("--extract_video", dict(
            action="store_true", default=False,
            help="Extract video codes. Enable this flag to dump video codes.")),
        ("--extract_text", dict(
            action="store_true", default=False,
            help="Extract text embeddings. Enable this flag to dump text embeddings.")),
    ]
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)

    return parser.parse_args()
|
|
|
|
|
|
|
|
def _text_encoder_model_id(architecture):
    """Map a ``--text_encoder_architecture`` choice to its HuggingFace model ID.

    Raises:
        ValueError: if ``architecture`` is not a supported choice.
    """
    model_ids = {
        "umt5-base": "google/umt5-base",
        "umt5-xxl": "google/umt5-xxl",
        "t5": "t5-base",
    }
    try:
        return model_ids[architecture]
    except KeyError:
        raise ValueError(f"Unknown text encoder: {architecture}") from None


def _resolve_video_root_dir(csv_path, video_root_dir):
    """Return the directory that holds the video files.

    An explicit ``video_root_dir`` wins. Otherwise a ``video_reorg`` folder
    next to the CSV is used, then one a level above it. As a last resort, the
    CSV's own directory is used and a warning is logged.
    """
    if video_root_dir is not None:
        return video_root_dir
    csv_dir = os.path.dirname(csv_path)
    for parent in (csv_dir, os.path.dirname(csv_dir)):
        candidate = os.path.join(parent, 'video_reorg')
        if os.path.exists(candidate):
            return candidate
    logger.warning(f"Video directory not found, using CSV directory: {csv_dir}")
    return csv_dir


def _save_empty_embeds(text_encoder, tokenizer, architecture, output_dir, device):
    """Encode the empty prompt and save it to ``output_dir/empty_embeds.npy``.

    The embedding is stored as float16 and is meant for conditional dropout
    at training time. Returns the saved array's shape as a list.
    """
    logger.info("Extracting empty_embeds for conditional dropout...")
    with torch.no_grad():
        empty_input_ids = tokenize_prompt(tokenizer, "", architecture).to(device)
        empty_embeds, _ = encode_prompt(text_encoder, empty_input_ids, architecture)
    empty_embeds_np = empty_embeds.cpu().numpy().astype(np.float16)
    empty_embeds_path = os.path.join(output_dir, "empty_embeds.npy")
    np.save(empty_embeds_path, empty_embeds_np)
    logger.info(f"Saved empty_embeds to: {empty_embeds_path}")
    logger.info(f"  Shape: {empty_embeds_np.shape}, dtype: {empty_embeds_np.dtype}")
    return list(empty_embeds_np.shape)


def _write_json(path, payload):
    """Write ``payload`` to ``path`` as indented JSON, replacing the file atomically.

    The data goes to a temporary sibling file first and is then
    ``os.replace``-d into place. An interrupted run therefore never leaves a
    truncated checkpoint behind.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(payload, f, indent=2)
    os.replace(tmp_path, path)


def _merge_process_metadata(output_dir, num_processes):
    """Combine the ``metadata_process_{rank}.json`` files written by each rank.

    Missing files are skipped.

    Returns:
        Tuple ``(samples, num_extracted, failed, shared)``:
        - ``samples``: all sample records, sorted by ``"index"``.
        - ``num_extracted``: sum of each rank's ``num_extracted``.
        - ``failed``: concatenated failed sample indices.
        - ``shared``: maps codebook_size, mask_token_id, empty_embeds_shape and
          empty_embeds_path to the first non-None value found.
    """
    shared_keys = ("codebook_size", "mask_token_id", "empty_embeds_shape", "empty_embeds_path")
    shared = dict.fromkeys(shared_keys)
    samples = []
    failed = []
    num_extracted = 0
    for proc_idx in range(num_processes):
        proc_metadata_file = os.path.join(output_dir, f"metadata_process_{proc_idx}.json")
        if not os.path.exists(proc_metadata_file):
            continue
        with open(proc_metadata_file, 'r') as f:
            proc_meta = json.load(f)
        samples.extend(proc_meta.get("samples", []))
        num_extracted += proc_meta.get("num_extracted", 0)
        failed.extend(proc_meta.get("failed_samples", []))
        for key in shared_keys:
            if shared[key] is None and proc_meta.get(key) is not None:
                shared[key] = proc_meta[key]
    samples.sort(key=lambda sample: sample["index"])
    return samples, num_extracted, failed, shared


def main():
    """Extract video codes and/or text embeddings for every dataset sample.

    Intended to run once per GPU under ``accelerate``. Each process works as
    follows:
    - It takes a round-robin slice of the dataset.
    - It saves one ``.npy`` per sample under ``video_codes/`` and/or
      ``text_embeddings/``, at the path given by ``get_hierarchical_path``.
    - It checkpoints its progress to ``metadata_process_{rank}.json`` every
      1000 samples.

    After all processes finish, the main process merges those files into
    ``metadata.json``.

    NOTE(review): with ``--resume_from_index``, the per-process and merged
    metadata files are rewritten with only the samples handled by this run.
    Entries recorded by an earlier run are not carried over.

    Raises:
        ValueError: if neither ``--extract_video`` nor ``--extract_text`` is set.
    """
    args = parse_args()
    if not args.extract_video and not args.extract_text:
        raise ValueError("At least one feature type must be enabled. Use --extract_video and/or --extract_text.")

    accelerator = Accelerator()
    device = accelerator.device
    dtype = torch.float32
    num_processes = accelerator.num_processes
    process_index = accelerator.process_index
    logger.info(f"Process {process_index}/{num_processes} on device {device}")

    # Every rank writes into output_dir. makedirs is idempotent, so each rank
    # creates it itself instead of racing the main process.
    os.makedirs(args.output_dir, exist_ok=True)
    if accelerator.is_main_process:
        logger.info(f"Output directory: {args.output_dir}")
        logger.info(f"Using {num_processes} GPUs for parallel extraction")
        logger.info(f"Extract video codes: {args.extract_video}")
        logger.info(f"Extract text embeddings: {args.extract_text}")

    text_model_id = _text_encoder_model_id(args.text_encoder_architecture)

    # ---- Text encoder and the unconditional ("empty") embedding ----
    text_encoder = None
    tokenizer = None
    empty_embeds_shape = None  # only filled on the main process, which writes the file
    if args.extract_text:
        logger.info(f"Loading text encoder: {args.text_encoder_architecture}")
        text_encoder = T5EncoderModel.from_pretrained(text_model_id)
        tokenizer = T5Tokenizer.from_pretrained(text_model_id)
        text_encoder.to(device=device, dtype=dtype)
        text_encoder.eval()
        text_encoder.requires_grad_(False)
        if accelerator.is_main_process:
            empty_embeds_shape = _save_empty_embeds(
                text_encoder, tokenizer, args.text_encoder_architecture, args.output_dir, device
            )
    else:
        logger.info("Skipping text encoder loading (--extract_text not set)")

    # OpenVid1MDataset is always built with a tokenizer (its batches carry
    # "prompt_input_ids"). In video-only runs, load the one matching the
    # configured architecture. Never replace a tokenizer already loaded for the
    # text encoder: ids from a different vocabulary would corrupt its embeddings.
    if tokenizer is None:
        tokenizer = T5Tokenizer.from_pretrained(text_model_id)

    # ---- Video tokenizer ----
    video_tokenizer = None
    codebook_size = None
    mask_token_id = None
    if args.extract_video:
        logger.info(f"Loading video tokenizer: {args.video_tokenizer_model_id}")
        video_tokenizer = CosmosVideoTokenizer(
            model_id=args.video_tokenizer_model_id,
            device=device,
            dtype=dtype,
        )
        video_tokenizer.requires_grad_(False)
        video_tokenizer.eval()
        codebook_size = video_tokenizer.codebook_size
        mask_token_id = video_tokenizer.mask_token_id
        logger.info(f"[GPU {process_index}] Video tokenizer info: codebook_size={codebook_size}, mask_token_id={mask_token_id}")
    else:
        logger.info("Skipping video tokenizer loading (--extract_video not set)")

    # ---- Dataset ----
    video_root_dir = _resolve_video_root_dir(args.csv_path, args.video_root_dir)
    dataset = OpenVid1MDataset(
        csv_path=args.csv_path,
        video_root_dir=video_root_dir,
        tokenizer=tokenizer,
        num_frames=args.num_frames,
        height=args.video_height,
        width=args.video_width,
        text_encoder_architecture=args.text_encoder_architecture,
        prompt_prefix=args.prompt_prefix,
        use_random_temporal_crop=False,
        use_random_crop=False,
    )
    if args.max_samples is not None:
        dataset.data = dataset.data[:args.max_samples]
        logger.info(f"Limited dataset to {len(dataset)} samples")
    if args.resume_from_index > 0:
        # Global sample index = resume_from_index + position in the sliced data.
        dataset.data = dataset.data[args.resume_from_index:]
        logger.info(f"Resuming from index {args.resume_from_index}, remaining samples: {len(dataset)}")
    total_samples = len(dataset)

    # ---- Sharding ----
    # Round-robin assignment, the same as DistributedSampler(shuffle=False), but
    # without its wrap-around padding. That padding would make some ranks
    # re-extract and re-record samples another rank already owns.
    local_indices = list(range(process_index, total_samples, num_processes))
    # The index list is used directly as the sampler. The loader is deliberately
    # NOT passed through accelerator.prepare(): accelerate would shard the
    # batches a second time, so batch_idx would no longer line up with
    # local_indices and features would be saved under the wrong sample index.
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=local_indices,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    # ---- Output locations ----
    video_codes_dir = None
    text_embeddings_dir = None
    if args.extract_video:
        video_codes_dir = os.path.join(args.output_dir, "video_codes")
        os.makedirs(video_codes_dir, exist_ok=True)
    if args.extract_text:
        text_embeddings_dir = os.path.join(args.output_dir, "text_embeddings")
        os.makedirs(text_embeddings_dir, exist_ok=True)
    metadata_file = os.path.join(args.output_dir, "metadata.json")
    process_metadata_file = os.path.join(args.output_dir, f"metadata_process_{process_index}.json")

    logger.info(f"[GPU {process_index}] Starting feature extraction for {total_samples} samples (process {process_index+1}/{num_processes})...")
    logger.info(f"[GPU {process_index}] This process will handle {len(local_indices)} samples")

    process_metadata = {
        "process_index": process_index,
        "num_samples": total_samples,
        "extract_video": args.extract_video,
        "extract_text": args.extract_text,
        "text_encoder_architecture": args.text_encoder_architecture if args.extract_text else None,
        "video_tokenizer_model_id": args.video_tokenizer_model_id if args.extract_video else None,
        "codebook_size": codebook_size,
        "mask_token_id": mask_token_id,
        "num_frames": args.num_frames,
        "video_height": args.video_height,
        "video_width": args.video_width,
        "prompt_prefix": args.prompt_prefix,
        "empty_embeds_shape": empty_embeds_shape,
        "empty_embeds_path": "empty_embeds.npy" if args.extract_text else None,
        "samples": [],
    }
    process_failed_samples = []
    process_samples_processed = 0

    # ---- Extraction loop ----
    with torch.no_grad():
        progress = tqdm(dataloader, desc=f"[GPU {process_index}] Extracting", disable=not accelerator.is_main_process)
        for batch_idx, batch in enumerate(progress):
            batch_size = batch["video"].shape[0] if args.extract_video else batch["prompt_input_ids"].shape[0]
            # Batches arrive in local_indices order. Only the last one may be
            # short, so this slice names exactly the samples in this batch.
            local_start_idx = batch_idx * args.batch_size
            batch_dataset_indices = local_indices[local_start_idx:local_start_idx + batch_size]
            batch_sample_indices = [args.resume_from_index + idx for idx in batch_dataset_indices]

            video_codes = None
            if args.extract_video:
                videos = batch["video"].to(device, non_blocking=True)
                try:
                    video_codes = video_tokenizer.encode(videos).cpu().numpy()
                except Exception as e:
                    logger.exception(f"[GPU {process_index}] Failed to encode video batch {batch_idx}: {e}")
                    process_failed_samples.extend(batch_sample_indices)
                    continue

            encoder_hidden_states = None
            if args.extract_text:
                prompt_input_ids = batch["prompt_input_ids"].to(device, non_blocking=True)
                try:
                    encoder_hidden_states, _ = encode_prompt(
                        text_encoder,
                        prompt_input_ids,
                        args.text_encoder_architecture,
                    )
                    encoder_hidden_states = encoder_hidden_states.cpu().numpy()
                except Exception as e:
                    # The whole batch is skipped (video codes included), so
                    # every sample in it is reported as failed.
                    logger.exception(f"[GPU {process_index}] Failed to encode text batch {batch_idx}: {e}")
                    process_failed_samples.extend(batch_sample_indices)
                    continue

            for i, (dataset_idx, sample_idx) in enumerate(zip(batch_dataset_indices, batch_sample_indices)):
                row = dataset.data[dataset_idx] if dataset_idx < len(dataset.data) else None

                video_code = None
                if video_codes is not None:
                    video_code_dir, video_code_path = get_hierarchical_path(video_codes_dir, sample_idx)
                    os.makedirs(video_code_dir, exist_ok=True)
                    video_code = video_codes[i].astype(np.int32)
                    np.save(video_code_path, video_code)

                text_embedding = None
                if encoder_hidden_states is not None:
                    text_embedding_dir, text_embedding_path = get_hierarchical_path(text_embeddings_dir, sample_idx)
                    os.makedirs(text_embedding_dir, exist_ok=True)
                    text_embedding = encoder_hidden_states[i].astype(np.float16)
                    np.save(text_embedding_path, text_embedding)

                if row is not None:
                    sample_meta = {
                        "index": sample_idx,
                        "video_path": row.get("video", ""),
                        "caption": row.get("caption", ""),
                    }
                    if video_code is not None:
                        sample_meta["video_code_shape"] = list(video_code.shape)
                    if text_embedding is not None:
                        sample_meta["text_embedding_shape"] = list(text_embedding.shape)
                    process_metadata["samples"].append(sample_meta)

                process_samples_processed += 1

                # Periodic checkpoint so an interrupted run keeps most of its metadata.
                if process_samples_processed % 1000 == 0:
                    process_metadata["num_extracted"] = process_samples_processed
                    process_metadata["failed_samples"] = process_failed_samples
                    _write_json(process_metadata_file, process_metadata)
                    logger.info(f"[GPU {process_index}] Progress: {process_samples_processed} samples extracted")

    # ---- Final per-process metadata ----
    process_metadata["num_extracted"] = process_samples_processed
    process_metadata["failed_samples"] = process_failed_samples
    _write_json(process_metadata_file, process_metadata)
    logger.info(f"[GPU {process_index}] Process complete: {process_samples_processed} samples extracted")

    # All ranks must have written their final file before the merge reads them.
    accelerator.wait_for_everyone()

    # ---- Merge (main process only) ----
    if not accelerator.is_main_process:
        return

    logger.info("Merging metadata from all processes...")
    all_samples, total_extracted, all_failed, shared = _merge_process_metadata(args.output_dir, num_processes)
    empty_embeds_path = shared["empty_embeds_path"]

    merged_metadata = {
        "num_samples": total_samples,
        "num_extracted": total_extracted,
        "num_processes": num_processes,
        "extract_video": args.extract_video,
        "extract_text": args.extract_text,
        "text_encoder_architecture": args.text_encoder_architecture if args.extract_text else None,
        "video_tokenizer_model_id": args.video_tokenizer_model_id if args.extract_video else None,
        "codebook_size": shared["codebook_size"],
        "mask_token_id": shared["mask_token_id"],
        "num_frames": args.num_frames,
        "video_height": args.video_height,
        "video_width": args.video_width,
        "prompt_prefix": args.prompt_prefix,
        "empty_embeds_shape": shared["empty_embeds_shape"],
        "empty_embeds_path": empty_embeds_path,
        "samples": all_samples,
        "failed_samples": sorted(set(all_failed)),
    }
    _write_json(metadata_file, merged_metadata)

    logger.info("Feature extraction complete!")
    logger.info(f"  Total samples: {total_samples}")
    logger.info(f"  Extracted: {total_extracted}")
    logger.info(f"  Failed: {len(merged_metadata['failed_samples'])}")
    if args.extract_video:
        logger.info(f"  Video codes saved to: {video_codes_dir}")
    if args.extract_text:
        logger.info(f"  Text embeddings saved to: {text_embeddings_dir}")
        if empty_embeds_path:
            logger.info(f"  Empty embeds saved to: {os.path.join(args.output_dir, empty_embeds_path)}")
    logger.info(f"  Metadata saved to: {metadata_file}")


if __name__ == "__main__":
    main()
|
|
|
|
|
|