"""
Extract video codes and text embeddings from video-text pairs for efficient training.

Adds resume-by-index_file and robust merge:
- --index_file: run only specified global indices (e.g. missing_process_0.txt)
- --run_tag: write new per-rank metadata as metadata_process_{rank}.{tag}.json
- Merge old possibly-truncated metadata_process_*.json + new tagged ones into metadata.json
"""
| |
|
| | import argparse |
| | import os |
| | import sys |
| | import logging |
| | from tqdm import tqdm |
| | import torch |
| | import numpy as np |
| | from torch.utils.data import DataLoader, DistributedSampler, Sampler |
| | import json |
| |
|
| | sys.path.append(os.getcwd()) |
| |
|
| | from train.dataset_utils import OpenVid1MDataset |
| | from src.pipeline_video import CosmosVideoTokenizer |
| | from transformers import T5Tokenizer, T5EncoderModel |
| | from accelerate import Accelerator |
| |
|
# Configure root logging once at import time; all module loggers inherit this format.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
| |
|
| |
|
| | |
| | |
| | |
def get_hierarchical_path(base_dir, index):
    """
    Map a global sample index to its hierarchical on-disk location.

    Structure: base_dir/level1/level2/level3/filename.npy
    - level1: index // 1000000 (0-999)
    - level2: (index // 1000) % 1000 (0-999)
    - level3: index % 1000 (0-999)

    Returns (dir_path, file_path); the file is named with the zero-padded
    8-digit global index.
    """
    buckets = (index // 1_000_000, (index // 1_000) % 1_000, index % 1_000)
    dir_path = os.path.join(base_dir, *(f"{b:03d}" for b in buckets))
    file_path = os.path.join(dir_path, f"{index:08d}.npy")
    return dir_path, file_path
| |
|
| |
|
def atomic_save_npy(path: str, arr: np.ndarray):
    """
    Atomically write *arr* to *path* as .npy.

    Writes to a temporary sibling file first, then renames over the target
    with os.replace (atomic on POSIX within one filesystem), so readers
    never observe a partially-written file.
    """
    # Fix: os.path.dirname(path) is "" for a bare filename, and
    # os.makedirs("") raises FileNotFoundError — only create the parent
    # when there actually is one.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    tmp = path + ".tmp"
    with open(tmp, "wb") as f:
        np.save(f, arr)
    os.replace(tmp, path)
| |
|
| |
|
def atomic_save_json(path: str, obj, indent=2):
    """
    Atomically write *obj* to *path* as JSON.

    Same tmp-then-rename pattern as atomic_save_npy so a crash mid-write
    never leaves a truncated JSON file at the final path.
    """
    # Fix: guard against an empty dirname (bare filename) — os.makedirs("")
    # raises FileNotFoundError.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump(obj, f, indent=indent)
    os.replace(tmp, path)
| |
|
| |
|
def safe_mmap_shape(npy_path: str):
    """
    Return the shape of an .npy file as a list, or None if it cannot be read.

    Uses mmap_mode="r" so only the header is touched — cheap even for
    very large arrays.
    """
    try:
        return list(np.load(npy_path, mmap_mode="r").shape)
    except Exception:
        return None
| |
|
| |
|
def normalize_input_ids(x: torch.Tensor) -> torch.Tensor:
    """
    Normalize token-id tensors to rank 2 (batch, seq_len).

    A 3-D tensor with a singleton middle or trailing axis is squeezed;
    any other rank/shape raises ValueError.
    """
    if x.ndim == 2:
        return x
    if x.ndim == 3:
        if x.shape[1] == 1:
            return x.squeeze(1)
        if x.shape[2] == 1:
            return x.squeeze(2)
        raise ValueError(f"Unexpected input_ids shape: {tuple(x.shape)}")
    raise ValueError(f"Unexpected input_ids ndim: {x.ndim}, shape={tuple(x.shape)}")
| |
|
| |
|
def get_feature_paths(args, video_codes_dir, text_embeddings_dir, attention_masks_dir, sample_idx: int):
    """
    Resolve the on-disk .npy path for each feature enabled by *args*.

    Returns a dict with any of the keys "video", "text", "mask", each
    mapped to its hierarchical file path for *sample_idx*.
    """
    selectors = (
        ("video", args.extract_video, video_codes_dir),
        ("text", args.extract_text, text_embeddings_dir),
        ("mask", args.save_attention_mask, attention_masks_dir),
    )
    paths = {}
    for key, enabled, root in selectors:
        if enabled:
            paths[key] = get_hierarchical_path(root, sample_idx)[1]
    return paths
| |
|
| |
|
| | |
| | |
| | |
def iter_samples_salvage(meta_path: str):
    """
    Read possibly-truncated metadata_process_*.json and salvage complete objects in "samples":[...].

    Generator: streams the file in 1 MiB chunks, locates the '"samples"'
    array, then yields each syntactically-complete JSON object found inside
    it using a hand-rolled character-level brace matcher. Complete objects
    before the truncation point are recovered; a partial trailing object is
    silently dropped. Yields nothing if the file is missing or the
    '"samples"' array is never found.
    """
    p = meta_path
    if not p or not os.path.exists(p):
        pass  # fall through via the existence check below
    if not os.path.exists(p):
        return
    # utf-8-sig tolerates a BOM at the start of the file.
    with open(p, "r", encoding="utf-8-sig") as f:
        # Phase 1: scan forward chunk-by-chunk until '"samples"' and its
        # opening '[' are seen; keep only the text after the '['.
        buf = ""
        found = False
        while True:
            chunk = f.read(1024 * 1024)
            if not chunk:
                break
            buf += chunk
            k = buf.find('"samples"')
            if k != -1:
                b = buf.find("[", k)
                if b != -1:
                    buf = buf[b + 1 :]
                    found = True
                    break
            # Cap memory while searching: keep only the trailing 4 MiB once
            # the buffer exceeds 8 MiB (the target key fits well within it).
            if len(buf) > 8 * 1024 * 1024:
                buf = buf[-4 * 1024 * 1024 :]
        if not found:
            return

        # Phase 2 state: a tiny JSON tokenizer tracking string context,
        # escape sequences, and brace depth of the object being collected.
        in_string = False
        escape = False
        depth = 0
        collecting = False
        obj = []

        def feed(ch):
            # Feed one character; returns:
            #   None       - keep going
            #   "__END__"  - hit the closing ']' of the samples array
            #   "__BAD__"  - collected braces balanced but json.loads failed
            #   dict       - one complete, parseable sample object
            nonlocal in_string, escape, depth, collecting, obj
            if in_string:
                if collecting:
                    obj.append(ch)
                if escape:
                    escape = False
                else:
                    if ch == "\\":
                        escape = True
                    elif ch == '"':
                        in_string = False
                return None

            if ch == '"':
                in_string = True
                if collecting:
                    obj.append(ch)
                return None

            # ']' outside any object closes the samples array itself.
            if not collecting and ch == "]":
                return "__END__"

            if ch == "{":
                if not collecting:
                    collecting = True
                    depth = 1
                    obj = ["{"]
                else:
                    depth += 1
                    obj.append("{")
                return None

            if collecting:
                obj.append(ch)
                if ch == "{":
                    depth += 1
                elif ch == "}":
                    depth -= 1
                    if depth == 0:
                        # Object braces balanced: try to parse it.
                        s = "".join(obj)
                        collecting = False
                        obj = []
                        try:
                            return json.loads(s)
                        except Exception:
                            return "__BAD__"
            return None

        def consume(text):
            # Sub-generator: yields dicts, returns "__END__" when the array
            # closes ("__BAD__" results are simply skipped).
            for ch in text:
                out = feed(ch)
                if out == "__END__":
                    return "__END__"
                if isinstance(out, dict):
                    yield out

        # Drain the buffered remainder first, then the rest of the file.
        end = yield from consume(buf)
        if end == "__END__":
            return

        while True:
            chunk = f.read(1024 * 1024)
            if not chunk:
                break
            end = yield from consume(chunk)
            if end == "__END__":
                return
| |
|
| |
|
def merge_metadata(output_dir: str, merge_world_size: int, run_tag: str):
    """
    Merge:
    - old metadata_process_{r}.json (may be truncated -> salvage)
    - new metadata_process_{r}.{run_tag}.json (assumed complete)
    into metadata.json (index-dedup, prefer new).

    New (tagged) entries win on index collision because they are applied
    after the salvaged ones. Header/config fields are copied from rank 0's
    metadata (tagged file preferred). An existing metadata.json is moved to
    metadata.json.bak before the merged file is written atomically.
    """
    outdir = output_dir
    by_idx = {}  # global sample index -> sample metadata dict

    # Pass 1: salvage whatever is readable from the (possibly truncated)
    # untagged per-rank files.
    for r in range(merge_world_size):
        old_p = os.path.join(outdir, f"metadata_process_{r}.json")
        # NOTE: iter_samples_salvage is a generator function, so the call
        # always returns a (truthy) generator; the `or []` is a harmless
        # belt-and-braces guard.
        for s in iter_samples_salvage(old_p) or []:
            idx = s.get("index", None)
            if idx is None:
                continue
            by_idx[int(idx)] = s

    # Pass 2: overlay the new tagged files (assumed complete JSON),
    # overwriting any salvaged entry with the same index.
    if run_tag:
        for r in range(merge_world_size):
            new_p = os.path.join(outdir, f"metadata_process_{r}.{run_tag}.json")
            if not os.path.exists(new_p):
                continue
            try:
                with open(new_p, "r") as f:
                    meta = json.load(f)
                for s in meta.get("samples", []):
                    idx = s.get("index", None)
                    if idx is None:
                        continue
                    by_idx[int(idx)] = s
            except Exception as e:
                logger.warning(f"Failed to load new meta {new_p}: {e}")

    # Deterministic output order: ascending global index.
    samples = [by_idx[k] for k in sorted(by_idx.keys())]

    merged = {
        "num_extracted_metadata": len(samples),
        "world_size_used": merge_world_size,
        "samples": samples,
    }

    # Pick a source for run-level header fields: rank 0's tagged file if
    # available, otherwise rank 0's old untagged file.
    header_src = None
    if run_tag:
        p0 = os.path.join(outdir, f"metadata_process_0.{run_tag}.json")
        if os.path.exists(p0):
            header_src = p0
    if header_src is None:
        p0_old = os.path.join(outdir, "metadata_process_0.json")
        if os.path.exists(p0_old):
            header_src = p0_old

    if header_src:
        # Best-effort: a truncated header file will fail json.load and the
        # merged metadata simply goes out without these fields.
        try:
            with open(header_src, "r") as f:
                m0 = json.load(f)
            for k in [
                "extract_video","extract_text","text_encoder_architecture","video_tokenizer_model_id",
                "codebook_size","mask_token_id","num_frames","video_height","video_width",
                "prompt_prefix","text_dtype","save_attention_mask","empty_embeds_shape","empty_embeds_path",
                "num_samples_original","resume_from_index","num_samples_this_run","num_attempted",
                "num_extracted","num_failed","num_processes","ranks_seen"
            ]:
                if k in m0 and m0[k] is not None:
                    merged[k] = m0[k]
        except Exception:
            pass

    metadata_file = os.path.join(outdir, "metadata.json")
    # Preserve the previous merged file as a backup before replacing it.
    if os.path.exists(metadata_file):
        bak = os.path.join(outdir, "metadata.json.bak")
        try:
            os.replace(metadata_file, bak)
            logger.info(f"Backed up old metadata.json -> {bak}")
        except Exception:
            pass

    atomic_save_json(metadata_file, merged, indent=2)
    logger.info(f"[MERGE] Wrote {metadata_file}, samples={len(samples):,}")
| |
|
| |
|
| | |
| | |
| | |
class IndexListSampler(Sampler):
    """Sampler that yields a fixed, caller-supplied sequence of dataset indices in order."""

    def __init__(self, indices):
        # Snapshot the indices so later mutation of the caller's sequence
        # cannot change what this sampler yields.
        self.indices = list(indices)

    def __iter__(self):
        yield from self.indices

    def __len__(self):
        return len(self.indices)
| |
|
| |
|
| | |
| | |
| | |
def parse_args():
    """Build and parse the CLI arguments for the feature-extraction run."""
    parser = argparse.ArgumentParser(description="Extract video codes and text embeddings")

    # --- Input / output locations ---
    parser.add_argument("--csv_path", type=str, required=True, help="Path to OpenVid1M CSV file")
    parser.add_argument("--video_root_dir", type=str, default=None, help="Root directory containing video files. If None, auto-detect.")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory to save extracted features")

    # --- Model / preprocessing configuration ---
    parser.add_argument(
        "--text_encoder_architecture",
        type=str,
        default="umt5-base",
        choices=["umt5-base", "umt5-xxl", "t5"],
        help="Text encoder architecture",
    )
    parser.add_argument(
        "--video_tokenizer_model_id",
        type=str,
        default="Cosmos-1.0-Tokenizer-DV8x16x16",
        help="HuggingFace model ID for Cosmos video tokenizer",
    )
    parser.add_argument("--num_frames", type=int, default=16, help="Number of frames per video")
    parser.add_argument("--video_height", type=int, default=480, help="Video height")
    parser.add_argument("--video_width", type=int, default=848, help="Video width")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size for feature extraction")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of dataloader workers")
    parser.add_argument("--max_samples", type=int, default=None, help="Max samples (for testing). If None, process all.")
    parser.add_argument("--resume_from_index", type=int, default=0, help="Resume extraction from this index")
    parser.add_argument("--prompt_prefix", type=str, default=None, help="Prefix to add to prompts")

    # --- Which feature types to extract (at least one required; checked in main()) ---
    parser.add_argument("--extract_video", action="store_true", default=False, help="Extract video codes")
    parser.add_argument("--extract_text", action="store_true", default=False, help="Extract text embeddings")

    parser.add_argument("--text_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"], help="Text encoder dtype")

    # --- Existing-file handling (main() gives --overwrite precedence) ---
    parser.add_argument("--skip_existing", action="store_true", help="Skip samples whose feature .npy already exist.")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing .npy (disables skip_existing).")

    # Paired on/off flags for one boolean destination; default is on.
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--save_attention_mask", dest="save_attention_mask", action="store_true",
                       help="Save attention mask per sample (default: on).")
    group.add_argument("--no_save_attention_mask", dest="save_attention_mask", action="store_false",
                       help="Do NOT save attention mask per sample.")
    parser.set_defaults(save_attention_mask=True)

    # --- Resume / merge controls ---
    parser.add_argument(
        "--index_file",
        type=str,
        default=None,
        help="Text file with one global sample index per line. "
             "Can contain '{rank}' placeholder, e.g. missing_process_{rank}.txt"
    )
    parser.add_argument(
        "--run_tag",
        type=str,
        default=None,
        help="Tag for new per-rank metadata file: metadata_process_{rank}.{run_tag}.json"
    )
    parser.add_argument(
        "--merge_world_size",
        type=int,
        default=None,
        help="How many ranks to merge for final metadata.json. "
             "Set to 8 for your OpenVid1M run even if you resume with 1 GPU."
    )

    return parser.parse_args()
| |
|
| |
|
| | |
| | |
| | |
def main():
    """
    Distributed feature-extraction driver.

    Per rank: load the enabled encoders, shard the dataset (either via an
    explicit --index_file or a DistributedSampler), extract and atomically
    save per-sample .npy features, and write per-rank metadata. Rank 0 then
    merges all per-rank metadata into metadata.json.
    """
    args = parse_args()
    accelerator = Accelerator()

    rank = accelerator.process_index
    world_size = accelerator.num_processes
    logger.info(f"Process {rank}/{world_size} on device {accelerator.device}")

    # Only the main process creates the output dir and logs the run config.
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
        logger.info(f"Output directory: {args.output_dir}")
        logger.info(f"Using {accelerator.num_processes} GPUs for parallel extraction")
        logger.info(f"Extract video codes: {args.extract_video}")
        logger.info(f"Extract text embeddings: {args.extract_text}")
        logger.info(f"skip_existing={args.skip_existing}, overwrite={args.overwrite}, save_attention_mask={args.save_attention_mask}")
        logger.info(f"index_file={args.index_file}, run_tag={args.run_tag}, merge_world_size={args.merge_world_size}")

    if not args.extract_video and not args.extract_text:
        raise ValueError("At least one feature type must be enabled. Use --extract_video and/or --extract_text.")

    device = accelerator.device
    dtype = torch.float32  # video tokenizer runs in fp32

    # --- Text encoder setup (only when extracting text) ---
    text_encoder = None
    tokenizer = None
    if args.extract_text:
        logger.info(f"Loading text encoder: {args.text_encoder_architecture} with dtype {args.text_dtype}")
        text_dtype_map = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
        text_dtype = text_dtype_map[args.text_dtype]

        if args.text_encoder_architecture == "umt5-base":
            model_id = "google/umt5-base"
        elif args.text_encoder_architecture == "umt5-xxl":
            model_id = "google/umt5-xxl"
        elif args.text_encoder_architecture == "t5":
            model_id = "t5-base"
        else:
            raise ValueError(f"Unknown text encoder: {args.text_encoder_architecture}")

        text_encoder = T5EncoderModel.from_pretrained(model_id, torch_dtype=text_dtype)
        tokenizer = T5Tokenizer.from_pretrained(model_id)
        text_encoder.to(device=device)
        text_encoder.eval()
        text_encoder.requires_grad_(False)

        # Rank 0 also precomputes the empty-prompt embedding used for
        # classifier-free conditional dropout at training time.
        if accelerator.is_main_process:
            logger.info("Extracting empty_embeds for conditional dropout...")
            with torch.no_grad():
                empty = tokenizer("", return_tensors="pt", padding="max_length", max_length=512, truncation=True)
                empty_ids = empty["input_ids"].to(device)
                empty_mask = empty["attention_mask"].to(device)
                outputs = text_encoder(input_ids=empty_ids, attention_mask=empty_mask)
                empty_embeds = outputs.last_hidden_state

            # bf16 has no direct numpy counterpart; upcast before converting,
            # then store compactly as fp16.
            empty_embeds_cpu = empty_embeds.cpu()
            if empty_embeds_cpu.dtype == torch.bfloat16:
                empty_embeds_cpu = empty_embeds_cpu.to(torch.float32)
            empty_embeds_np = empty_embeds_cpu.numpy().astype(np.float16)
            empty_embeds_path = os.path.join(args.output_dir, "empty_embeds.npy")
            np.save(empty_embeds_path, empty_embeds_np)
            logger.info(f"Saved empty_embeds to: {empty_embeds_path} shape={empty_embeds_np.shape} dtype={empty_embeds_np.dtype}")
    else:
        logger.info("Skipping text encoder loading (extract_text=False)")
        # The dataset still needs a tokenizer to produce prompt_input_ids
        # even when only extracting video features.
        if args.extract_video:
            tokenizer = T5Tokenizer.from_pretrained("google/umt5-base")

    # --- Video tokenizer setup (only when extracting video) ---
    video_tokenizer = None
    if args.extract_video:
        logger.info(f"Loading video tokenizer: {args.video_tokenizer_model_id}")
        video_tokenizer = CosmosVideoTokenizer(model_id=args.video_tokenizer_model_id, device=device, dtype=dtype)
        video_tokenizer.requires_grad_(False)
        video_tokenizer.eval()
    else:
        logger.info("Skipping video tokenizer loading (extract_video=False)")

    # --- Resolve the video root: explicit flag, else look for a
    # "video_reorg" folder next to (or one level above) the CSV. ---
    if args.video_root_dir is None:
        csv_dir = os.path.dirname(args.csv_path)
        if os.path.exists(os.path.join(csv_dir, "video_reorg")):
            video_root_dir = os.path.join(csv_dir, "video_reorg")
        elif os.path.exists(os.path.join(os.path.dirname(csv_dir), "video_reorg")):
            video_root_dir = os.path.join(os.path.dirname(csv_dir), "video_reorg")
        else:
            video_root_dir = csv_dir
            logger.warning(f"Video directory not found, using CSV directory: {video_root_dir}")
    else:
        video_root_dir = args.video_root_dir

    # Deterministic crops so extracted features are reproducible.
    dataset = OpenVid1MDataset(
        csv_path=args.csv_path,
        video_root_dir=video_root_dir,
        tokenizer=tokenizer,
        num_frames=args.num_frames,
        height=args.video_height,
        width=args.video_width,
        text_encoder_architecture=args.text_encoder_architecture,
        prompt_prefix=args.prompt_prefix,
        use_random_temporal_crop=False,
        use_random_crop=False,
    )

    if args.max_samples is not None:
        dataset.data = dataset.data[:args.max_samples]
        logger.info(f"Limited dataset to {len(dataset)} samples")

    # Resuming trims the head of the dataset; global sample index is then
    # recovered as resume_from_index + local dataset index everywhere below.
    if args.resume_from_index > 0:
        dataset.data = dataset.data[args.resume_from_index:]
        logger.info(f"Resuming from index {args.resume_from_index}, remaining samples: {len(dataset)}")

    num_processes = accelerator.num_processes
    process_index = accelerator.process_index

    # --- Choose the sharding strategy ---
    if args.index_file is not None:
        # Explicit per-rank index list (global indices, '#' comments allowed).
        idx_path = args.index_file.format(rank=process_index)
        with open(idx_path, "r") as f:
            wanted_sample_idx = [int(x.strip()) for x in f if x.strip() and not x.strip().startswith("#")]

        # Convert global indices to positions within the (possibly trimmed)
        # dataset, dropping any that fall outside it.
        sampler_indices = []
        for sample_idx in wanted_sample_idx:
            global_dataset_idx = sample_idx - args.resume_from_index
            if 0 <= global_dataset_idx < len(dataset.data):
                sampler_indices.append(global_dataset_idx)

        sampler = IndexListSampler(sampler_indices)
        logger.info(f"[GPU {process_index}] Using index_file={idx_path}, indices={len(sampler_indices)}")
    else:
        sampler = DistributedSampler(
            dataset,
            num_replicas=num_processes,
            rank=process_index,
            shuffle=False,
            drop_last=False,
        )
        # Materialize this rank's assignment so sample indices can be
        # reconstructed batch-by-batch below.
        sampler_indices = list(sampler)

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    # --- Output subdirectories per feature type ---
    video_codes_dir = None
    text_embeddings_dir = None
    attention_masks_dir = None
    if args.extract_video:
        video_codes_dir = os.path.join(args.output_dir, "video_codes")
        os.makedirs(video_codes_dir, exist_ok=True)
    if args.extract_text:
        text_embeddings_dir = os.path.join(args.output_dir, "text_embeddings")
        os.makedirs(text_embeddings_dir, exist_ok=True)
        if args.save_attention_mask:
            attention_masks_dir = os.path.join(args.output_dir, "attention_masks")
            os.makedirs(attention_masks_dir, exist_ok=True)

    # --- Load previously-merged metadata so skipped samples keep their
    # recorded shapes without re-reading every .npy header. ---
    metadata_file = os.path.join(args.output_dir, "metadata.json")
    existing_shapes = {}
    if os.path.exists(metadata_file):
        try:
            with open(metadata_file, "r") as f:
                existing_meta = json.load(f)
            for sample in existing_meta.get("samples", []):
                idx = sample.get("index")
                if idx is None:
                    continue
                existing_shapes[int(idx)] = {
                    "video_code_shape": sample.get("video_code_shape"),
                    "text_embedding_shape": sample.get("text_embedding_shape"),
                    "context_len": sample.get("context_len"),
                }
            logger.info(f"[GPU {process_index}] Loaded existing metadata for {len(existing_shapes)} samples")
        except Exception as e:
            logger.warning(f"[GPU {process_index}] Failed to load existing metadata: {e}")

    total_samples = len(dataset)
    logger.info(f"[GPU {process_index}] Starting feature extraction for {total_samples} samples "
                f"(process {process_index+1}/{num_processes}), assigned={len(sampler_indices)}")

    # Tokenizer constants recorded into metadata for downstream training.
    codebook_size = None
    mask_token_id = None
    if args.extract_video and video_tokenizer is not None:
        codebook_size = getattr(video_tokenizer, "codebook_size", None)
        mask_token_id = getattr(video_tokenizer, "mask_token_id", None)
        logger.info(f"[GPU {process_index}] Video tokenizer: codebook_size={codebook_size}, mask_token_id={mask_token_id}")

    # Record the empty-embeds shape (rank 0 only, best-effort).
    empty_embeds_shape = None
    empty_embeds_path = os.path.join(args.output_dir, "empty_embeds.npy")
    if args.extract_text and accelerator.is_main_process and os.path.exists(empty_embeds_path):
        try:
            empty_embeds_np = np.load(empty_embeds_path, mmap_mode="r")
            empty_embeds_shape = list(empty_embeds_np.shape)
            logger.info(f"Empty embeds shape: {empty_embeds_shape}")
        except Exception:
            pass

    # Per-rank metadata accumulated during the run and flushed periodically.
    process_metadata = {
        "process_index": process_index,
        "num_samples_this_run": total_samples,
        "world_size_used": world_size,
        "rank_used": rank,
        "extract_video": args.extract_video,
        "extract_text": args.extract_text,
        "text_encoder_architecture": args.text_encoder_architecture if args.extract_text else None,
        "video_tokenizer_model_id": args.video_tokenizer_model_id if args.extract_video else None,
        "codebook_size": codebook_size,
        "mask_token_id": mask_token_id,
        "num_frames": args.num_frames,
        "video_height": args.video_height,
        "video_width": args.video_width,
        "prompt_prefix": args.prompt_prefix,
        "text_dtype": args.text_dtype if args.extract_text else None,
        "save_attention_mask": args.save_attention_mask,
        "empty_embeds_shape": empty_embeds_shape if process_index == 0 else None,
        "empty_embeds_path": "empty_embeds.npy" if args.extract_text else None,
        "samples": [],
    }

    process_failed_samples = []
    process_samples_processed = 0
    process_attempted_samples = 0

    with torch.no_grad():
        for batch_idx, batch in enumerate(
            tqdm(dataloader, desc=f"[GPU {process_index}] Extracting", disable=not accelerator.is_main_process)
        ):
            # Last batch may be smaller than args.batch_size.
            batch_size = batch["video"].shape[0] if args.extract_video else batch["prompt_input_ids"].shape[0]
            local_start_idx = batch_idx * args.batch_size

            # Reconstruct each row's GLOBAL sample index from the sampler
            # assignment (None pads positions past this rank's share).
            batch_sample_indices = []
            for i in range(batch_size):
                local_idx = local_start_idx + i
                if local_idx < len(sampler_indices):
                    global_dataset_idx = sampler_indices[local_idx]
                    sample_idx = args.resume_from_index + global_dataset_idx
                    batch_sample_indices.append(sample_idx)
                else:
                    batch_sample_indices.append(None)

            # Per-sample output paths for each enabled feature.
            batch_paths = []
            for sidx in batch_sample_indices:
                if sidx is None:
                    batch_paths.append(None)
                else:
                    batch_paths.append(get_feature_paths(args, video_codes_dir, text_embeddings_dir, attention_masks_dir, sidx))

            # Decide per sample which features actually need computing.
            need_video = [False] * batch_size
            need_text = [False] * batch_size
            need_mask = [False] * batch_size

            for i, (sidx, paths) in enumerate(zip(batch_sample_indices, batch_paths)):
                if sidx is None:
                    continue

                if args.overwrite:
                    # --overwrite wins over --skip_existing: recompute all.
                    need_video[i] = args.extract_video
                    need_text[i] = args.extract_text
                    need_mask[i] = args.extract_text and args.save_attention_mask
                elif args.skip_existing:
                    # Recompute only features whose file is missing.
                    # NOTE(review): need_mask is only set inside the
                    # extract_text branch, and the mask is only re-saved
                    # below when need_text[i] is also True — a sample with
                    # existing text but missing mask is never repaired.
                    # Confirm this is intended.
                    if args.extract_video:
                        need_video[i] = not os.path.exists(paths["video"])
                    if args.extract_text:
                        need_text[i] = not os.path.exists(paths["text"])
                        if args.save_attention_mask:
                            need_mask[i] = not os.path.exists(paths["mask"])
                else:
                    need_video[i] = args.extract_video
                    need_text[i] = args.extract_text
                    need_mask[i] = args.extract_text and args.save_attention_mask

            need_any = [v or t or m for v, t, m in zip(need_video, need_text, need_mask)]

            # Fast-skip fully-done batches — but NOT in index_file mode,
            # where metadata for each listed sample must still be recorded.
            if (not any(need_any)) and (args.index_file is None):
                continue

            process_attempted_samples += sum(need_any)

            if batch_idx == 0:
                preview = [x for x in batch_sample_indices[:5] if x is not None]
                logger.info(f"[GPU {process_index}] First batch sample indices: {preview}")
                logger.info(f"[GPU {process_index}] Need any in first batch: {sum(need_any)}/{len(need_any)}")

            # --- Video encoding: only the rows that need it ---
            video_codes = None
            need_video_idx = [i for i, ok in enumerate(need_video) if ok]
            # batch row index -> position within the selected sub-batch
            map_video_pos = {i: p for p, i in enumerate(need_video_idx)}
            if args.extract_video and len(need_video_idx) > 0:
                videos = batch["video"].to(device, non_blocking=True)
                videos_sel = videos[need_video_idx]
                try:
                    vc_sel = video_tokenizer.encode(videos_sel)
                    video_codes = vc_sel.detach().cpu().numpy()
                except Exception as e:
                    # Record every affected sample and drop the whole batch
                    # (text features for this batch are skipped too).
                    logger.error(f"[GPU {process_index}] Failed to encode video batch {batch_idx}: {e}")
                    for i in need_video_idx:
                        sidx = batch_sample_indices[i]
                        if sidx is not None:
                            process_failed_samples.append({"index": sidx, "reason": "video_encoding_failed"})
                    continue

            # --- Text encoding ---
            encoder_hidden_states = None
            attention_masks = None
            context_lens_np_full = None

            if args.extract_text:
                prompt_input_ids = batch["prompt_input_ids"].to(device, non_blocking=True)
                if isinstance(prompt_input_ids, (tuple, list)):
                    prompt_input_ids = prompt_input_ids[0]
                prompt_input_ids = normalize_input_ids(prompt_input_ids).long()

                # Attention mask / context length derived from padding; computed
                # for the FULL batch so context_len is recorded even for rows
                # whose embedding already exists.
                pad_id = tokenizer.pad_token_id
                attention_mask_full = (prompt_input_ids != pad_id).long()
                context_lens_full = attention_mask_full.sum(dim=-1)
                context_lens_np_full = context_lens_full.detach().cpu().numpy().astype(np.int32)

                need_text_idx = [i for i, ok in enumerate(need_text) if ok]
                map_text_pos = {i: p for p, i in enumerate(need_text_idx)}

                if len(need_text_idx) > 0:
                    ids_sel = prompt_input_ids[need_text_idx]
                    mask_sel = attention_mask_full[need_text_idx]
                    try:
                        outputs = text_encoder(input_ids=ids_sel, attention_mask=mask_sel)
                        enc = outputs.last_hidden_state.detach().cpu()
                        # bf16 -> fp32 before numpy, stored as fp16.
                        if enc.dtype == torch.bfloat16:
                            enc = enc.to(torch.float32)
                        encoder_hidden_states = enc.numpy().astype(np.float16)
                        if args.save_attention_mask:
                            attention_masks = mask_sel.detach().cpu().numpy().astype(np.int32)
                    except Exception as e:
                        logger.error(f"[GPU {process_index}] Failed to encode text batch {batch_idx}: {e}")
                        for i in need_text_idx:
                            sidx = batch_sample_indices[i]
                            if sidx is not None:
                                process_failed_samples.append({"index": sidx, "reason": "text_encoding_failed"})
                        continue
                else:
                    map_text_pos = {}
            else:
                need_text_idx = []
                map_text_pos = {}

            # --- Per-sample save + metadata ---
            for i in range(batch_size):
                local_idx = local_start_idx + i
                if local_idx >= len(sampler_indices):
                    continue

                global_dataset_idx = sampler_indices[local_idx]
                sample_idx = args.resume_from_index + global_dataset_idx
                paths = batch_paths[i]

                row = dataset.data[global_dataset_idx] if global_dataset_idx < len(dataset.data) else None
                if row is None:
                    continue

                existing_info = existing_shapes.get(sample_idx, {})

                video_code_shape = None
                if args.extract_video:
                    if need_video[i]:
                        if video_codes is None:
                            process_failed_samples.append({"index": sample_idx, "reason": "video_codes_none"})
                            continue
                        pos = map_video_pos[i]
                        video_code = video_codes[pos].astype(np.int32)
                        try:
                            atomic_save_npy(paths["video"], video_code)
                            video_code_shape = list(video_code.shape)
                        except Exception as e:
                            process_failed_samples.append({"index": sample_idx, "reason": f"video_save_failed: {str(e)}"})
                            continue
                    else:
                        # Skipped: reuse recorded shape, else read the header.
                        video_code_shape = existing_info.get("video_code_shape") or safe_mmap_shape(paths["video"])

                text_embedding_shape = None
                if args.extract_text:
                    if need_text[i]:
                        if encoder_hidden_states is None:
                            process_failed_samples.append({"index": sample_idx, "reason": "text_embeddings_none"})
                            continue
                        pos = map_text_pos[i]
                        text_emb = encoder_hidden_states[pos]
                        try:
                            atomic_save_npy(paths["text"], text_emb)
                            text_embedding_shape = list(text_emb.shape)
                        except Exception as e:
                            process_failed_samples.append({"index": sample_idx, "reason": f"text_save_failed: {str(e)}"})
                            continue

                        # Mask is saved alongside a freshly-computed text
                        # embedding (same sub-batch position `pos`).
                        if args.save_attention_mask and need_mask[i]:
                            if attention_masks is None:
                                process_failed_samples.append({"index": sample_idx, "reason": "attention_masks_none"})
                                continue
                            try:
                                atomic_save_npy(paths["mask"], attention_masks[pos])
                            except Exception as e:
                                process_failed_samples.append({"index": sample_idx, "reason": f"mask_save_failed: {str(e)}"})
                                continue
                    else:
                        text_embedding_shape = existing_info.get("text_embedding_shape") or safe_mmap_shape(paths["text"])

                sample_meta = {
                    "index": sample_idx,
                    "video_path": row.get("video", ""),
                    "caption": row.get("caption", ""),
                }
                if args.extract_video and video_code_shape is not None:
                    sample_meta["video_code_shape"] = video_code_shape
                if args.extract_text and text_embedding_shape is not None:
                    sample_meta["text_embedding_shape"] = text_embedding_shape
                if args.extract_text and context_lens_np_full is not None:
                    sample_meta["context_len"] = int(context_lens_np_full[i])

                process_metadata["samples"].append(sample_meta)
                process_samples_processed += 1

            # Periodic checkpoint of per-rank metadata. Checked once per
            # batch, so it fires only when the running total lands exactly
            # on a multiple of 1000 at a batch boundary (approximate cadence).
            if process_samples_processed > 0 and (process_samples_processed % 1000 == 0):
                suffix = f".{args.run_tag}" if args.run_tag else ""
                process_metadata_file = os.path.join(args.output_dir, f"metadata_process_{process_index}{suffix}.json")
                process_metadata["num_extracted"] = process_samples_processed
                process_metadata["failed_samples"] = process_failed_samples
                atomic_save_json(process_metadata_file, process_metadata, indent=2)
                logger.info(f"[GPU {process_index}] Progress: {process_samples_processed} samples recorded -> {process_metadata_file}")

    accelerator.wait_for_everyone()

    # --- Final per-rank metadata (tagged with run_tag if given) ---
    suffix = f".{args.run_tag}" if args.run_tag else ""
    process_metadata_file = os.path.join(args.output_dir, f"metadata_process_{process_index}{suffix}.json")
    process_metadata["num_attempted"] = int(process_attempted_samples)
    process_metadata["num_extracted"] = int(process_samples_processed)
    process_metadata["num_failed"] = int(len(process_failed_samples))
    process_metadata["failed_samples"] = process_failed_samples
    atomic_save_json(process_metadata_file, process_metadata, indent=2)

    logger.info(f"[GPU {process_index}] Done: attempted={process_attempted_samples}, extracted(meta)={process_samples_processed}, failed={len(process_failed_samples)}")
    accelerator.wait_for_everyone()

    # --- Merge all per-rank metadata into the global metadata.json ---
    if accelerator.is_main_process:
        merge_world = args.merge_world_size if args.merge_world_size is not None else world_size
        logger.info(f"[MERGE] merging world_size={merge_world} (run_tag={args.run_tag})")
        merge_metadata(args.output_dir, merge_world_size=merge_world, run_tag=args.run_tag)
| |
|
| |
|
# Script entry point: run the distributed extraction driver.
if __name__ == "__main__":
    main()
| |
|