#!/usr/bin/env python3
"""
Extract video codes and text embeddings from video-text pairs for efficient training.

Adds resume-by-index_file and robust merge:
- --index_file: run only specified global indices (e.g. missing_process_0.txt)
- --run_tag: write new per-rank metadata as metadata_process_{rank}.{tag}.json
- Merge old possibly-truncated metadata_process_*.json + new tagged ones into metadata.json

Outputs written under --output_dir:
- video_codes/AAA/BBB/CCC/NNNNNNNN.npy      int32 video codes (with --extract_video)
- text_embeddings/AAA/BBB/CCC/NNNNNNNN.npy  float16 text-encoder hidden states (with --extract_text)
- attention_masks/AAA/BBB/CCC/NNNNNNNN.npy  int32 prompt attention masks (with --extract_text,
  unless --no_save_attention_mask)
- empty_embeds.npy                          float16 encoding of the empty prompt (with --extract_text)
- metadata_process_{rank}[.{run_tag}].json  per-rank metadata
- metadata.json                             merged metadata (written by the main process)
"""

import argparse
import json
import logging
import os
import shutil
import sys

import numpy as np
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, DistributedSampler, Sampler
from tqdm import tqdm
from transformers import T5EncoderModel, T5Tokenizer

# Make project-local packages (train/, src/) importable when launched from the repo root.
sys.path.append(os.getcwd())

from train.dataset_utils import OpenVid1MDataset  # noqa: E402
from src.pipeline_video import CosmosVideoTokenizer  # noqa: E402

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Streaming parameters for salvaging truncated metadata files (sizes in characters).
_READ_CHUNK_CHARS = 1024 * 1024
_SALVAGE_MAX_BUF_CHARS = 8 * 1024 * 1024
_SALVAGE_KEEP_BUF_CHARS = 4 * 1024 * 1024

# Padding length used when encoding the empty prompt for empty_embeds.npy.
_EMPTY_PROMPT_MAX_LENGTH = 512

# Per-rank metadata is re-written every this many recorded samples.
_PERIODIC_SAVE_EVERY = 1000

_TEXT_DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

_TEXT_MODEL_IDS = {
    "umt5-base": "google/umt5-base",
    "umt5-xxl": "google/umt5-xxl",
    "t5": "t5-base",
}

# Header fields copied from a rank-0 metadata file into the merged metadata.json.
# NOTE: counters such as num_extracted/num_failed come from rank 0 only, not a global sum.
_MERGE_HEADER_KEYS = (
    "extract_video", "extract_text", "text_encoder_architecture", "video_tokenizer_model_id",
    "codebook_size", "mask_token_id", "num_frames", "video_height", "video_width",
    "prompt_prefix", "text_dtype", "save_attention_mask", "empty_embeds_shape", "empty_embeds_path",
    "num_samples_original", "resume_from_index", "num_samples_this_run", "num_attempted",
    "num_extracted", "num_failed", "num_processes", "ranks_seen",
)


# -----------------------------
# IO helpers
# -----------------------------
def get_hierarchical_path(base_dir, index):
    """
    Return (dir_path, file_path) for sample ``index`` under ``base_dir``.

    Structure: base_dir/level1/level2/level3/filename.npy
    - level1: index // 1000000 (0-999)
    - level2: (index // 1000) % 1000 (0-999)
    - level3: index % 1000 (0-999)
    The file name is the index zero-padded to 8 digits.
    """
    level1 = index // 1000000
    level2 = (index // 1000) % 1000
    level3 = index % 1000
    dir_path = os.path.join(base_dir, f"{level1:03d}", f"{level2:03d}", f"{level3:03d}")
    file_path = os.path.join(dir_path, f"{index:08d}.npy")
    return dir_path, file_path


def _tmp_path(path: str) -> str:
    """Per-process temporary name for an atomic write to ``path``.

    The pid keeps two processes that target the same final path (e.g. indices duplicated by
    DistributedSampler padding) from clobbering each other's temp file.
    """
    return f"{path}.{os.getpid()}.tmp"


def _ensure_parent_dir(path: str) -> None:
    """Create the parent directory of ``path`` if it has one."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def atomic_save_npy(path: str, arr: np.ndarray):
    """Write ``arr`` to ``path`` in .npy format via a temp file + ``os.replace``.

    Readers never observe a partially written file at ``path``.
    """
    _ensure_parent_dir(path)
    tmp = _tmp_path(path)
    with open(tmp, "wb") as f:
        np.save(f, arr)
    os.replace(tmp, path)


def atomic_save_json(path: str, obj, indent=2):
    """Write ``obj`` as JSON to ``path`` via a temp file + ``os.replace``."""
    _ensure_parent_dir(path)
    tmp = _tmp_path(path)
    with open(tmp, "w") as f:
        json.dump(obj, f, indent=indent)
    os.replace(tmp, path)


def safe_mmap_shape(npy_path: str):
    """Best-effort: return the shape of the array stored at ``npy_path`` as a list.

    Returns None if the file is missing or cannot be read as .npy. The file is memory-mapped,
    so the data itself is not loaded.
    """
    try:
        arr = np.load(npy_path, mmap_mode="r")
        return list(arr.shape)
    except Exception:  # deliberate best-effort: any unreadable file just means "shape unknown"
        return None


def normalize_input_ids(x: torch.Tensor) -> torch.Tensor:
    """Squeeze a singleton middle or last dim from 3-D token ids, yielding a 2-D tensor.

    Raises:
        ValueError: if ``x`` is 3-D without a singleton dim 1 or 2, or is not 2-D/3-D.
    """
    if x.ndim == 3:
        if x.shape[1] == 1:
            x = x.squeeze(1)
        elif x.shape[2] == 1:
            x = x.squeeze(2)
        else:
            raise ValueError(f"Unexpected input_ids shape: {tuple(x.shape)}")
    elif x.ndim != 2:
        raise ValueError(f"Unexpected input_ids ndim: {x.ndim}, shape={tuple(x.shape)}")
    return x


def _unwrap_ids(x):
    """Return the first element when a batch field arrives wrapped in a tuple/list, else ``x``."""
    if isinstance(x, (tuple, list)):
        return x[0]
    return x


def get_feature_paths(args, video_codes_dir, text_embeddings_dir, attention_masks_dir, sample_idx: int):
    """Return {"video"|"text"|"mask": path} for every feature this run writes for ``sample_idx``.

    Attention masks are only produced alongside text embeddings, so the "mask" entry requires
    both ``extract_text`` and ``save_attention_mask`` (otherwise ``attention_masks_dir`` is None).
    """
    paths = {}
    if args.extract_video:
        _, paths["video"] = get_hierarchical_path(video_codes_dir, sample_idx)
    if args.extract_text:
        _, paths["text"] = get_hierarchical_path(text_embeddings_dir, sample_idx)
        if args.save_attention_mask:
            _, paths["mask"] = get_hierarchical_path(attention_masks_dir, sample_idx)
    return paths


# -----------------------------
# robust salvage for truncated JSON
# -----------------------------
def iter_samples_salvage(meta_path: str):
    """
    Yield every complete sample object from a possibly-truncated metadata_process_*.json.

    The file is streamed in chunks. First the text is scanned for the ``"samples"`` key and the
    ``[`` that follows it. Then a small string/brace-aware state machine extracts each top-level
    ``{...}`` element of that array and parses it with ``json.loads``. Elements that fail to parse
    are skipped, an unterminated trailing element (from a truncated write) is dropped, and
    iteration stops at the array's closing ``]``.

    Args:
        meta_path: Path to the per-rank metadata file. A missing file yields nothing.

    Yields:
        dict: Each successfully parsed sample object.
    """
    if not os.path.exists(meta_path):
        return

    with open(meta_path, "r", encoding="utf-8-sig") as f:
        chunks = iter(lambda: f.read(_READ_CHUNK_CHARS), "")

        # Phase 1: find the opening bracket of the "samples" array.
        buf = ""
        found = False
        for chunk in chunks:
            buf += chunk
            key_pos = buf.find('"samples"')
            if key_pos != -1:
                bracket_pos = buf.find("[", key_pos)
                if bracket_pos != -1:
                    buf = buf[bracket_pos + 1:]
                    found = True
                    break
            if len(buf) > _SALVAGE_MAX_BUF_CHARS:
                # Keep only the tail so a huge header cannot grow the buffer without bound.
                buf = buf[-_SALVAGE_KEEP_BUF_CHARS:]
        if not found:
            return

        # Phase 2: scan array elements; parser state persists across chunk boundaries.
        in_string = False
        escape = False
        depth = 0
        obj = None  # characters of the element being collected, or None between elements

        def consume(text):
            """Yield complete elements found in ``text``; return True once the array's ']' is seen."""
            nonlocal in_string, escape, depth, obj
            for ch in text:
                if in_string:
                    if obj is not None:
                        obj.append(ch)
                    if escape:
                        escape = False
                    elif ch == "\\":
                        escape = True
                    elif ch == '"':
                        in_string = False
                    continue
                if ch == '"':
                    in_string = True
                    if obj is not None:
                        obj.append(ch)
                    continue
                if obj is None:
                    if ch == "]":
                        return True
                    if ch == "{":
                        obj = ["{"]
                        depth = 1
                    continue
                obj.append(ch)
                if ch == "{":
                    depth += 1
                elif ch == "}":
                    depth -= 1
                    if depth == 0:
                        element = "".join(obj)
                        obj = None
                        try:
                            parsed = json.loads(element)
                        except (ValueError, RecursionError):
                            continue  # malformed element: skip it, keep salvaging
                        if isinstance(parsed, dict):
                            yield parsed
            return False

        if (yield from consume(buf)):
            return
        for chunk in chunks:
            if (yield from consume(chunk)):
                return


def merge_metadata(output_dir: str, merge_world_size: int, run_tag: str):
    """
    Merge per-rank metadata into ``output_dir/metadata.json``.

    Sources, in increasing priority (later ones override by sample ``index``):
      - old ``metadata_process_{r}.json`` (may be truncated -> salvaged element by element)
      - new ``metadata_process_{r}.{run_tag}.json`` (expected complete; only when run_tag is set)

    Samples are written sorted by index. Header fields are copied best-effort from the rank-0
    tagged file, falling back to the rank-0 untagged file. An existing metadata.json is first
    copied to metadata.json.bak; the original stays in place until the new file atomically
    replaces it.
    """
    by_idx = {}

    # Old files first so newer tagged results override them.
    for r in range(merge_world_size):
        old_p = os.path.join(output_dir, f"metadata_process_{r}.json")
        for s in iter_samples_salvage(old_p):
            idx = s.get("index", None)
            if idx is None:
                continue
            by_idx[int(idx)] = s

    if run_tag:
        for r in range(merge_world_size):
            new_p = os.path.join(output_dir, f"metadata_process_{r}.{run_tag}.json")
            if not os.path.exists(new_p):
                continue
            try:
                with open(new_p, "r", encoding="utf-8-sig") as f:
                    meta = json.load(f)
                for s in meta.get("samples", []):
                    idx = s.get("index", None)
                    if idx is None:
                        continue
                    by_idx[int(idx)] = s
            except Exception as e:  # one bad rank file must not abort the whole merge
                logger.warning(f"Failed to load new meta {new_p}: {e}")

    samples = [by_idx[k] for k in sorted(by_idx)]
    merged = {
        "num_extracted_metadata": len(samples),
        "world_size_used": merge_world_size,
        "samples": samples,
    }

    # Pull header info from rank 0: tagged file preferred, then the untagged one.
    header_candidates = []
    if run_tag:
        header_candidates.append(os.path.join(output_dir, f"metadata_process_0.{run_tag}.json"))
    header_candidates.append(os.path.join(output_dir, "metadata_process_0.json"))
    header_src = next((p for p in header_candidates if os.path.exists(p)), None)
    if header_src:
        try:
            with open(header_src, "r", encoding="utf-8-sig") as f:
                m0 = json.load(f)
        except (OSError, ValueError) as e:
            # Expected when the rank-0 file is the truncated one being salvaged.
            logger.warning(f"[MERGE] Could not read header fields from {header_src}: {e}")
        else:
            if isinstance(m0, dict):
                for k in _MERGE_HEADER_KEYS:
                    if m0.get(k) is not None:
                        merged[k] = m0[k]

    metadata_file = os.path.join(output_dir, "metadata.json")
    if os.path.exists(metadata_file):
        bak = os.path.join(output_dir, "metadata.json.bak")
        try:
            # Copy (not move) so a failed write below never leaves the directory without metadata.json.
            shutil.copy2(metadata_file, bak)
            logger.info(f"Backed up old metadata.json -> {bak}")
        except OSError as e:
            logger.warning(f"Failed to back up {metadata_file}: {e}")

    atomic_save_json(metadata_file, merged, indent=2)
    logger.info(f"[MERGE] Wrote {metadata_file}, samples={len(samples):,}")


# -----------------------------
# Sampler for index list
# -----------------------------
class IndexListSampler(Sampler):
    """Sampler that yields a fixed list of dataset indices in the given order."""

    def __init__(self, indices):
        self.indices = list(indices)

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


# -----------------------------
# Args
# -----------------------------
def parse_args():
    """Parse command-line arguments for feature extraction."""
    parser = argparse.ArgumentParser(description="Extract video codes and text embeddings")
    parser.add_argument("--csv_path", type=str, required=True, help="Path to OpenVid1M CSV file")
    parser.add_argument("--video_root_dir", type=str, default=None, help="Root directory containing video files. If None, auto-detect.")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory to save extracted features")
    parser.add_argument(
        "--text_encoder_architecture",
        type=str,
        default="umt5-base",
        choices=["umt5-base", "umt5-xxl", "t5"],
        help="Text encoder architecture",
    )
    parser.add_argument(
        "--video_tokenizer_model_id",
        type=str,
        default="Cosmos-1.0-Tokenizer-DV8x16x16",
        help="HuggingFace model ID for Cosmos video tokenizer",
    )
    parser.add_argument("--num_frames", type=int, default=16, help="Number of frames per video")
    parser.add_argument("--video_height", type=int, default=480, help="Video height")
    parser.add_argument("--video_width", type=int, default=848, help="Video width")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size for feature extraction")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of dataloader workers")
    parser.add_argument("--max_samples", type=int, default=None, help="Max samples (for testing). If None, process all.")
    parser.add_argument("--resume_from_index", type=int, default=0, help="Resume extraction from this index")
    parser.add_argument("--prompt_prefix", type=str, default=None, help="Prefix to add to prompts")
    parser.add_argument("--extract_video", action="store_true", default=False, help="Extract video codes")
    parser.add_argument("--extract_text", action="store_true", default=False, help="Extract text embeddings")
    parser.add_argument("--text_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"], help="Text encoder dtype")
    parser.add_argument("--skip_existing", action="store_true", help="Skip samples whose feature .npy already exist.")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing .npy (disables skip_existing).")

    group = parser.add_mutually_exclusive_group()
    group.add_argument("--save_attention_mask", dest="save_attention_mask", action="store_true", help="Save attention mask per sample (default: on).")
    group.add_argument("--no_save_attention_mask", dest="save_attention_mask", action="store_false", help="Do NOT save attention mask per sample.")
    parser.set_defaults(save_attention_mask=True)

    # resume / merge additions
    parser.add_argument(
        "--index_file",
        type=str,
        default=None,
        help="Text file with one global sample index per line. "
        "Can contain '{rank}' placeholder, e.g. missing_process_{rank}.txt",
    )
    parser.add_argument(
        "--run_tag",
        type=str,
        default=None,
        help="Tag for new per-rank metadata file: metadata_process_{rank}.{run_tag}.json",
    )
    parser.add_argument(
        "--merge_world_size",
        type=int,
        default=None,
        help="How many ranks to merge for final metadata.json. "
        "Set to 8 for your OpenVid1M run even if you resume with 1 GPU.",
    )
    return parser.parse_args()


# -----------------------------
# Main helpers
# -----------------------------
def _resolve_text_model_id(architecture):
    """Map a --text_encoder_architecture choice to its HuggingFace model id."""
    try:
        return _TEXT_MODEL_IDS[architecture]
    except KeyError:
        raise ValueError(f"Unknown text encoder: {architecture}") from None


def _save_empty_embeds(text_encoder, tokenizer, device, output_dir):
    """Encode the empty prompt and write it to output_dir/empty_embeds.npy as float16."""
    logger.info("Extracting empty_embeds for conditional dropout...")
    with torch.no_grad():
        empty = tokenizer(
            "", return_tensors="pt", padding="max_length", max_length=_EMPTY_PROMPT_MAX_LENGTH, truncation=True
        )
        outputs = text_encoder(
            input_ids=empty["input_ids"].to(device),
            attention_mask=empty["attention_mask"].to(device),
        )
        empty_embeds_cpu = outputs.last_hidden_state.cpu()
    # numpy has no bfloat16; upcast before the float16 conversion.
    if empty_embeds_cpu.dtype == torch.bfloat16:
        empty_embeds_cpu = empty_embeds_cpu.to(torch.float32)
    empty_embeds_np = empty_embeds_cpu.numpy().astype(np.float16)
    empty_embeds_path = os.path.join(output_dir, "empty_embeds.npy")
    atomic_save_npy(empty_embeds_path, empty_embeds_np)
    logger.info(f"Saved empty_embeds to: {empty_embeds_path} shape={empty_embeds_np.shape} dtype={empty_embeds_np.dtype}")


def _resolve_video_root_dir(csv_path, video_root_dir):
    """Return ``video_root_dir``, or auto-detect a ``video_reorg`` folder next to / above the CSV."""
    if video_root_dir is not None:
        return video_root_dir
    csv_dir = os.path.dirname(csv_path)
    for candidate in (
        os.path.join(csv_dir, "video_reorg"),
        os.path.join(os.path.dirname(csv_dir), "video_reorg"),
    ):
        if os.path.exists(candidate):
            return candidate
    logger.warning(f"Video directory not found, using CSV directory: {csv_dir}")
    return csv_dir


def _build_sampler(args, dataset, process_index, num_processes):
    """Return (sampler, sampler_indices) for this rank.

    With --index_file, only the listed global sample indices that fall inside the (possibly
    sliced) dataset are used, in file order. Otherwise a non-shuffling DistributedSampler
    shards the dataset. With drop_last=False it pads so ranks get equal counts, which can assign
    a sample to two ranks; the merge dedups by index.
    """
    if args.index_file is not None:
        idx_path = args.index_file.format(rank=process_index)
        with open(idx_path, "r") as f:
            wanted_sample_idx = [int(x.strip()) for x in f if x.strip() and not x.strip().startswith("#")]
        n_rows = len(dataset.data)
        sampler_indices = [
            s - args.resume_from_index for s in wanted_sample_idx if 0 <= s - args.resume_from_index < n_rows
        ]
        logger.info(f"[GPU {process_index}] Using index_file={idx_path}, indices={len(sampler_indices)}")
        return IndexListSampler(sampler_indices), sampler_indices

    sampler = DistributedSampler(
        dataset,
        num_replicas=num_processes,
        rank=process_index,
        shuffle=False,
        drop_last=False,
    )
    return sampler, list(sampler)


def _load_existing_shapes(metadata_file, process_index):
    """Best-effort map of sample index -> recorded shapes from an existing metadata.json."""
    existing_shapes = {}
    if not os.path.exists(metadata_file):
        return existing_shapes
    try:
        with open(metadata_file, "r") as f:
            existing_meta = json.load(f)
        for sample in existing_meta.get("samples", []):
            idx = sample.get("index")
            if idx is None:
                continue
            existing_shapes[int(idx)] = {
                "video_code_shape": sample.get("video_code_shape"),
                "text_embedding_shape": sample.get("text_embedding_shape"),
                "context_len": sample.get("context_len"),
            }
        logger.info(f"[GPU {process_index}] Loaded existing metadata for {len(existing_shapes)} samples")
    except Exception as e:  # a stale/corrupt metadata.json only disables the shape cache
        logger.warning(f"[GPU {process_index}] Failed to load existing metadata: {e}")
    return existing_shapes


def _compute_need_flags(args, batch_sample_indices, batch_paths):
    """Return (need_video, need_text, need_mask) lists: which outputs to (re)write per batch slot.

    --overwrite (or neither flag) writes every enabled output. --skip_existing alone writes only
    outputs whose file is missing. Padding slots (sample index None) need nothing.
    """
    n = len(batch_sample_indices)
    need_video = [False] * n
    need_text = [False] * n
    need_mask = [False] * n
    want_mask = args.extract_text and args.save_attention_mask
    check_existing = args.skip_existing and not args.overwrite
    for i, (sidx, paths) in enumerate(zip(batch_sample_indices, batch_paths)):
        if sidx is None:
            continue
        if check_existing:
            need_video[i] = args.extract_video and not os.path.exists(paths["video"])
            need_text[i] = args.extract_text and not os.path.exists(paths["text"])
            need_mask[i] = want_mask and not os.path.exists(paths["mask"])
        else:
            need_video[i] = args.extract_video
            need_text[i] = args.extract_text
            need_mask[i] = want_mask
    return need_video, need_text, need_mask


def _process_metadata_path(output_dir, process_index, run_tag):
    """Path of this rank's metadata file, tagged with run_tag when given."""
    suffix = f".{run_tag}" if run_tag else ""
    return os.path.join(output_dir, f"metadata_process_{process_index}{suffix}.json")


def _write_process_metadata(path, process_metadata, attempted, extracted, failed_samples):
    """Update the counters in ``process_metadata`` and atomically write it to ``path``."""
    process_metadata["num_attempted"] = int(attempted)
    process_metadata["num_extracted"] = int(extracted)
    process_metadata["num_failed"] = int(len(failed_samples))
    process_metadata["failed_samples"] = failed_samples
    atomic_save_json(path, process_metadata, indent=2)


# -----------------------------
# Main
# -----------------------------
def main():
    """Run per-rank extraction, write per-rank metadata, then merge on the main process."""
    args = parse_args()
    accelerator = Accelerator()
    rank = accelerator.process_index
    world_size = accelerator.num_processes
    logger.info(f"Process {rank}/{world_size} on device {accelerator.device}")

    # Validate on every rank so no process continues (and later blocks at a barrier) with a bad config.
    if not args.extract_video and not args.extract_text:
        raise ValueError("At least one feature type must be enabled. Use --extract_video and/or --extract_text.")
    if args.index_file is not None and not args.run_tag:
        logger.warning(
            "--index_file without --run_tag rewrites metadata_process_{rank}.json with only the "
            "resumed samples, replacing the earlier run's per-rank metadata. Pass --run_tag to keep both."
        )

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
        logger.info(f"Output directory: {args.output_dir}")
        logger.info(f"Using {accelerator.num_processes} GPUs for parallel extraction")
        logger.info(f"Extract video codes: {args.extract_video}")
        logger.info(f"Extract text embeddings: {args.extract_text}")
        logger.info(f"skip_existing={args.skip_existing}, overwrite={args.overwrite}, save_attention_mask={args.save_attention_mask}")
        logger.info(f"index_file={args.index_file}, run_tag={args.run_tag}, merge_world_size={args.merge_world_size}")

    device = accelerator.device
    video_dtype = torch.float32

    # ---- text encoder/tokenizer
    text_encoder = None
    tokenizer = None
    if args.extract_text:
        logger.info(f"Loading text encoder: {args.text_encoder_architecture} with dtype {args.text_dtype}")
        text_dtype = _TEXT_DTYPES[args.text_dtype]
        model_id = _resolve_text_model_id(args.text_encoder_architecture)
        text_encoder = T5EncoderModel.from_pretrained(model_id, torch_dtype=text_dtype)
        tokenizer = T5Tokenizer.from_pretrained(model_id)
        text_encoder.to(device=device)
        text_encoder.eval()
        text_encoder.requires_grad_(False)
        if accelerator.is_main_process:
            _save_empty_embeds(text_encoder, tokenizer, device, args.output_dir)
    else:
        logger.info("Skipping text encoder loading (extract_text=False)")
        if args.extract_video:
            # The dataset still takes a tokenizer even when no text embeddings are produced.
            tokenizer = T5Tokenizer.from_pretrained("google/umt5-base")

    # ---- video tokenizer
    video_tokenizer = None
    if args.extract_video:
        logger.info(f"Loading video tokenizer: {args.video_tokenizer_model_id}")
        video_tokenizer = CosmosVideoTokenizer(model_id=args.video_tokenizer_model_id, device=device, dtype=video_dtype)
        video_tokenizer.requires_grad_(False)
        video_tokenizer.eval()
    else:
        logger.info("Skipping video tokenizer loading (extract_video=False)")

    video_root_dir = _resolve_video_root_dir(args.csv_path, args.video_root_dir)

    # ---- dataset
    dataset = OpenVid1MDataset(
        csv_path=args.csv_path,
        video_root_dir=video_root_dir,
        tokenizer=tokenizer,
        num_frames=args.num_frames,
        height=args.video_height,
        width=args.video_width,
        text_encoder_architecture=args.text_encoder_architecture,
        prompt_prefix=args.prompt_prefix,
        use_random_temporal_crop=False,
        use_random_crop=False,
    )
    if args.max_samples is not None:
        dataset.data = dataset.data[:args.max_samples]
        logger.info(f"Limited dataset to {len(dataset)} samples")
    if args.resume_from_index > 0:
        # After this slice, dataset position p corresponds to global sample index p + resume_from_index.
        dataset.data = dataset.data[args.resume_from_index:]
        logger.info(f"Resuming from index {args.resume_from_index}, remaining samples: {len(dataset)}")

    num_processes = accelerator.num_processes
    process_index = accelerator.process_index

    # ---- sampler: DistributedSampler OR IndexListSampler
    sampler, sampler_indices = _build_sampler(args, dataset, process_index, num_processes)

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    # ---- output dirs
    video_codes_dir = None
    text_embeddings_dir = None
    attention_masks_dir = None
    if args.extract_video:
        video_codes_dir = os.path.join(args.output_dir, "video_codes")
        os.makedirs(video_codes_dir, exist_ok=True)
    if args.extract_text:
        text_embeddings_dir = os.path.join(args.output_dir, "text_embeddings")
        os.makedirs(text_embeddings_dir, exist_ok=True)
        if args.save_attention_mask:
            attention_masks_dir = os.path.join(args.output_dir, "attention_masks")
            os.makedirs(attention_masks_dir, exist_ok=True)

    # ---- load existing metadata shapes (optional)
    metadata_file = os.path.join(args.output_dir, "metadata.json")
    existing_shapes = _load_existing_shapes(metadata_file, process_index)

    total_samples = len(dataset)
    logger.info(
        f"[GPU {process_index}] Starting feature extraction for {total_samples} samples "
        f"(process {process_index+1}/{num_processes}), assigned={len(sampler_indices)}"
    )

    # tokenizer info
    codebook_size = None
    mask_token_id = None
    if args.extract_video and video_tokenizer is not None:
        codebook_size = getattr(video_tokenizer, "codebook_size", None)
        mask_token_id = getattr(video_tokenizer, "mask_token_id", None)
        logger.info(f"[GPU {process_index}] Video tokenizer: codebook_size={codebook_size}, mask_token_id={mask_token_id}")

    # empty embeds info (only rank 0 wrote it, so only rank 0 reports its shape)
    empty_embeds_shape = None
    empty_embeds_path = os.path.join(args.output_dir, "empty_embeds.npy")
    if args.extract_text and accelerator.is_main_process and os.path.exists(empty_embeds_path):
        empty_embeds_shape = safe_mmap_shape(empty_embeds_path)
        if empty_embeds_shape is not None:
            logger.info(f"Empty embeds shape: {empty_embeds_shape}")

    process_metadata = {
        "process_index": process_index,
        "num_samples_this_run": total_samples,
        "num_assigned": len(sampler_indices),
        "world_size_used": world_size,
        "rank_used": rank,
        "extract_video": args.extract_video,
        "extract_text": args.extract_text,
        "text_encoder_architecture": args.text_encoder_architecture if args.extract_text else None,
        "video_tokenizer_model_id": args.video_tokenizer_model_id if args.extract_video else None,
        "codebook_size": codebook_size,
        "mask_token_id": mask_token_id,
        "num_frames": args.num_frames,
        "video_height": args.video_height,
        "video_width": args.video_width,
        "prompt_prefix": args.prompt_prefix,
        "text_dtype": args.text_dtype if args.extract_text else None,
        "save_attention_mask": args.save_attention_mask,
        "empty_embeds_shape": empty_embeds_shape if process_index == 0 else None,
        "empty_embeds_path": "empty_embeds.npy" if args.extract_text else None,
        "samples": [],
    }
    process_metadata_file = _process_metadata_path(args.output_dir, process_index, args.run_tag)

    process_failed_samples = []
    process_samples_processed = 0
    process_attempted_samples = 0  # samples with at least one output to (re)write

    with torch.no_grad():
        progress = tqdm(dataloader, desc=f"[GPU {process_index}] Extracting", disable=not accelerator.is_main_process)
        for batch_idx, batch in enumerate(progress):
            if args.extract_video:
                n_in_batch = batch["video"].shape[0]
            else:
                n_in_batch = _unwrap_ids(batch["prompt_input_ids"]).shape[0]

            # Both samplers yield sampler_indices in order and the DataLoader does not shuffle,
            # so slot i of batch b is sampler_indices[b * batch_size + i].
            local_start_idx = batch_idx * args.batch_size
            batch_dataset_indices = []
            for i in range(n_in_batch):
                local_idx = local_start_idx + i
                batch_dataset_indices.append(sampler_indices[local_idx] if local_idx < len(sampler_indices) else None)
            batch_sample_indices = [
                None if d is None else args.resume_from_index + d for d in batch_dataset_indices
            ]
            batch_paths = [
                None if sidx is None
                else get_feature_paths(args, video_codes_dir, text_embeddings_dir, attention_masks_dir, sidx)
                for sidx in batch_sample_indices
            ]

            need_video, need_text, need_mask = _compute_need_flags(args, batch_sample_indices, batch_paths)
            need_any = [v or t or m for v, t, m in zip(need_video, need_text, need_mask)]

            if batch_idx == 0:
                preview = [x for x in batch_sample_indices[:5] if x is not None]
                logger.info(f"[GPU {process_index}] First batch sample indices: {preview}")
                logger.info(f"[GPU {process_index}] Need any in first batch: {sum(need_any)}/{len(need_any)}")

            # Full run: skip batches with nothing to write.
            # index_file resume: still record metadata for every listed sample.
            if not any(need_any) and args.index_file is None:
                continue

            process_attempted_samples += sum(need_any)

            # ---- encode video for needed samples
            video_codes = None
            need_video_idx = [i for i, ok in enumerate(need_video) if ok]
            map_video_pos = {i: p for p, i in enumerate(need_video_idx)}
            if args.extract_video and need_video_idx:
                videos = batch["video"].to(device, non_blocking=True)
                try:
                    video_codes = video_tokenizer.encode(videos[need_video_idx]).detach().cpu().numpy()
                except Exception as e:  # record and move on; one bad batch must not stop the run
                    logger.error(f"[GPU {process_index}] Failed to encode video batch {batch_idx}: {e}")
                    for i in need_video_idx:
                        sidx = batch_sample_indices[i]
                        if sidx is not None:
                            process_failed_samples.append({"index": sidx, "reason": "video_encoding_failed"})
                    continue

            # ---- text: context lengths / masks for the whole batch, encoder only where needed
            encoder_hidden_states = None
            mask_np_full = None
            context_lens_np_full = None
            map_text_pos = {}
            if args.extract_text:
                # Unwrap before .to(): a tuple/list has no .to() method.
                prompt_input_ids = normalize_input_ids(_unwrap_ids(batch["prompt_input_ids"]))
                prompt_input_ids = prompt_input_ids.to(device, non_blocking=True).long()
                attention_mask_full = (prompt_input_ids != tokenizer.pad_token_id).long()
                context_lens_np_full = attention_mask_full.sum(dim=-1).detach().cpu().numpy().astype(np.int32)
                if args.save_attention_mask:
                    # Masks depend only on the token ids, so they can be (re)written even when the
                    # text embedding already exists and is not re-encoded.
                    mask_np_full = attention_mask_full.detach().cpu().numpy().astype(np.int32)

                need_text_idx = [i for i, ok in enumerate(need_text) if ok]
                map_text_pos = {i: p for p, i in enumerate(need_text_idx)}
                if need_text_idx:
                    try:
                        outputs = text_encoder(
                            input_ids=prompt_input_ids[need_text_idx],
                            attention_mask=attention_mask_full[need_text_idx],
                        )
                        enc = outputs.last_hidden_state.detach().cpu()
                        if enc.dtype == torch.bfloat16:
                            enc = enc.to(torch.float32)
                        encoder_hidden_states = enc.numpy().astype(np.float16)
                    except Exception as e:  # record and move on; one bad batch must not stop the run
                        logger.error(f"[GPU {process_index}] Failed to encode text batch {batch_idx}: {e}")
                        for i in need_text_idx:
                            sidx = batch_sample_indices[i]
                            if sidx is not None:
                                process_failed_samples.append({"index": sidx, "reason": "text_encoding_failed"})
                        continue

            # ---- save per sample + record metadata
            for i, sample_idx in enumerate(batch_sample_indices):
                if sample_idx is None:
                    continue
                global_dataset_idx = batch_dataset_indices[i]
                if global_dataset_idx >= len(dataset.data):
                    continue
                row = dataset.data[global_dataset_idx]
                paths = batch_paths[i]
                existing_info = existing_shapes.get(sample_idx, {})

                video_code_shape = None
                if args.extract_video:
                    if need_video[i]:
                        if video_codes is None:
                            process_failed_samples.append({"index": sample_idx, "reason": "video_codes_none"})
                            continue
                        video_code = video_codes[map_video_pos[i]].astype(np.int32)
                        try:
                            atomic_save_npy(paths["video"], video_code)
                        except Exception as e:
                            process_failed_samples.append({"index": sample_idx, "reason": f"video_save_failed: {str(e)}"})
                            continue
                        video_code_shape = list(video_code.shape)
                    else:
                        video_code_shape = existing_info.get("video_code_shape") or safe_mmap_shape(paths["video"])

                text_embedding_shape = None
                if args.extract_text:
                    if need_text[i]:
                        if encoder_hidden_states is None:
                            process_failed_samples.append({"index": sample_idx, "reason": "text_embeddings_none"})
                            continue
                        text_emb = encoder_hidden_states[map_text_pos[i]]
                        try:
                            atomic_save_npy(paths["text"], text_emb)
                        except Exception as e:
                            process_failed_samples.append({"index": sample_idx, "reason": f"text_save_failed: {str(e)}"})
                            continue
                        text_embedding_shape = list(text_emb.shape)
                    else:
                        text_embedding_shape = existing_info.get("text_embedding_shape") or safe_mmap_shape(paths["text"])

                    if need_mask[i]:
                        if mask_np_full is None:
                            process_failed_samples.append({"index": sample_idx, "reason": "attention_masks_none"})
                            continue
                        try:
                            atomic_save_npy(paths["mask"], mask_np_full[i])
                        except Exception as e:
                            process_failed_samples.append({"index": sample_idx, "reason": f"mask_save_failed: {str(e)}"})
                            continue

                sample_meta = {
                    "index": sample_idx,
                    "video_path": row.get("video", ""),
                    "caption": row.get("caption", ""),
                }
                if args.extract_video and video_code_shape is not None:
                    sample_meta["video_code_shape"] = video_code_shape
                if args.extract_text and text_embedding_shape is not None:
                    sample_meta["text_embedding_shape"] = text_embedding_shape
                if args.extract_text and context_lens_np_full is not None:
                    sample_meta["context_len"] = int(context_lens_np_full[i])
                process_metadata["samples"].append(sample_meta)
                process_samples_processed += 1

                # periodic per-process metadata save (atomic, tagged)
                if process_samples_processed % _PERIODIC_SAVE_EVERY == 0:
                    _write_process_metadata(
                        process_metadata_file, process_metadata,
                        process_attempted_samples, process_samples_processed, process_failed_samples,
                    )
                    logger.info(f"[GPU {process_index}] Progress: {process_samples_processed} samples recorded -> {process_metadata_file}")

    accelerator.wait_for_everyone()

    # final per-process metadata save
    _write_process_metadata(
        process_metadata_file, process_metadata,
        process_attempted_samples, process_samples_processed, process_failed_samples,
    )
    logger.info(f"[GPU {process_index}] Done: attempted={process_attempted_samples}, extracted(meta)={process_samples_processed}, failed={len(process_failed_samples)}")

    accelerator.wait_for_everyone()

    # merge on main process
    if accelerator.is_main_process:
        merge_world = args.merge_world_size if args.merge_world_size is not None else world_size
        logger.info(f"[MERGE] merging world_size={merge_world} (run_tag={args.run_tag})")
        merge_metadata(args.output_dir, merge_world_size=merge_world, run_tag=args.run_tag)


if __name__ == "__main__":
    main()