# 43.oT_eV / Meissonic / train / extract_features.py
# (Hub page residue preserved as comments: "BryanW's picture",
#  "Upload code from /mnt/43.oT_eV", "c2925de verified")
#!/usr/bin/env python3
"""
Extract video codes and text embeddings from video-text pairs for efficient training.
Adds resume-by-index_file and robust merge:
- --index_file: run only specified global indices (e.g. missing_process_0.txt)
- --run_tag: write new per-rank metadata as metadata_process_{rank}.{tag}.json
- Merge old possibly-truncated metadata_process_*.json + new tagged ones into metadata.json
"""
import argparse
import os
import sys
import logging
from tqdm import tqdm
import torch
import numpy as np
from torch.utils.data import DataLoader, DistributedSampler, Sampler
import json
sys.path.append(os.getcwd())
from train.dataset_utils import OpenVid1MDataset
from src.pipeline_video import CosmosVideoTokenizer
from transformers import T5Tokenizer, T5EncoderModel
from accelerate import Accelerator
# Configure root logging once at import time so every process logs with a
# timestamped, leveled prefix; module logger used throughout this file.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# -----------------------------
# IO helpers
# -----------------------------
def get_hierarchical_path(base_dir, index):
    """Map a global sample index to its 3-level directory and .npy file path.

    Layout: base_dir/level1/level2/level3/{index:08d}.npy where
      - level1 = index // 1000000   (000-999)
      - level2 = (index // 1000) % 1000
      - level3 = index % 1000

    Returns:
        (dir_path, file_path) tuple.
    """
    levels = (index // 1000000, (index // 1000) % 1000, index % 1000)
    dir_path = os.path.join(base_dir, *(f"{lvl:03d}" for lvl in levels))
    file_path = os.path.join(dir_path, f"{index:08d}.npy")
    return dir_path, file_path
def atomic_save_npy(path: str, arr: np.ndarray):
    """Atomically save *arr* to *path* as .npy (write tmp file, then rename).

    os.replace is atomic on POSIX, so readers never observe a half-written file.
    Fix: only call makedirs when the path actually has a directory component;
    os.makedirs("") raises FileNotFoundError for bare filenames.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    tmp = path + ".tmp"
    with open(tmp, "wb") as f:
        np.save(f, arr)
    os.replace(tmp, path)
def atomic_save_json(path: str, obj, indent=2):
    """Atomically save *obj* to *path* as JSON (write tmp file, then rename).

    os.replace is atomic on POSIX, so a crash mid-write never corrupts the
    destination file. Fix: only call makedirs when the path has a directory
    component; os.makedirs("") raises FileNotFoundError for bare filenames.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump(obj, f, indent=indent)
    os.replace(tmp, path)
def safe_mmap_shape(npy_path: str):
    """Return the shape of a .npy file as a list, or None if unreadable.

    Memory-maps the file so only the header is read — cheap even for huge
    arrays. Any failure (missing file, corrupt header) yields None.
    """
    try:
        return list(np.load(npy_path, mmap_mode="r").shape)
    except Exception:
        return None
def normalize_input_ids(x: torch.Tensor) -> torch.Tensor:
    """Collapse a singleton middle or trailing axis so token ids are 2-D.

    Accepts [B, S], [B, 1, S] or [B, S, 1] and returns [B, S]; any other
    shape raises ValueError.
    """
    if x.ndim == 2:
        return x
    if x.ndim == 3:
        if x.shape[1] == 1:
            return x.squeeze(1)
        if x.shape[2] == 1:
            return x.squeeze(2)
        raise ValueError(f"Unexpected input_ids shape: {tuple(x.shape)}")
    raise ValueError(f"Unexpected input_ids ndim: {x.ndim}, shape={tuple(x.shape)}")
def get_feature_paths(args, video_codes_dir, text_embeddings_dir, attention_masks_dir, sample_idx: int):
    """Resolve the per-sample output .npy path for each enabled feature kind.

    Returns a dict with a subset of keys {"video", "text", "mask"}, present
    only when the corresponding flag on *args* is set. Paths follow the
    hierarchical layout of get_hierarchical_path.
    """
    wanted = (
        ("video", args.extract_video, video_codes_dir),
        ("text", args.extract_text, text_embeddings_dir),
        ("mask", args.save_attention_mask, attention_masks_dir),
    )
    paths = {}
    for key, enabled, root_dir in wanted:
        if enabled:
            paths[key] = get_hierarchical_path(root_dir, sample_idx)[1]
    return paths
# -----------------------------
# robust salvage for truncated JSON
# -----------------------------
def iter_samples_salvage(meta_path: str):
    """
    Read possibly-truncated metadata_process_*.json and salvage complete objects in "samples":[...].

    Generator: yields each fully-parsed sample dict found inside the
    "samples" array, skipping anything malformed. Works on files whose tail
    was cut off mid-write (e.g. a non-atomic save interrupted by a crash) by
    scanning character-by-character instead of json.load-ing the whole file.
    Yields nothing if the file is missing or no "samples" array is found.
    """
    p = meta_path
    if not os.path.exists(p):
        return
    # utf-8-sig tolerates a BOM at the start of the file.
    with open(p, "r", encoding="utf-8-sig") as f:
        # Phase 1: stream forward until we find the opening '[' of "samples".
        buf = ""
        found = False
        while True:
            chunk = f.read(1024 * 1024)
            if not chunk:
                break
            buf += chunk
            k = buf.find('"samples"')
            if k != -1:
                b = buf.find("[", k)
                if b != -1:
                    # Keep only the text after '[' — the array contents.
                    buf = buf[b + 1 :]
                    found = True
                    break
            # Bound memory while searching: keep only the tail of the buffer.
            # NOTE(review): this assumes '"samples"' is not split across the
            # discarded boundary in pathological files — acceptable best-effort.
            if len(buf) > 8 * 1024 * 1024:
                buf = buf[-4 * 1024 * 1024 :]
        if not found:
            return
        # Phase 2: char-level state machine that extracts balanced {...}
        # objects from the array, tracking string/escape state so braces
        # inside strings are ignored.
        in_string = False
        escape = False
        depth = 0        # current {...} nesting depth while collecting
        collecting = False
        obj = []         # characters of the object currently being collected
        def feed(ch):
            # Feed one character; returns a dict when an object completes,
            # "__END__" at the array's closing ']', "__BAD__" on a parse
            # failure, or None otherwise.
            nonlocal in_string, escape, depth, collecting, obj
            if in_string:
                if collecting:
                    obj.append(ch)
                if escape:
                    escape = False
                else:
                    if ch == "\\":
                        escape = True
                    elif ch == '"':
                        in_string = False
                return None
            if ch == '"':
                in_string = True
                if collecting:
                    obj.append(ch)
                return None
            # A ']' outside any object closes the "samples" array.
            if not collecting and ch == "]":
                return "__END__"
            if ch == "{":
                if not collecting:
                    collecting = True
                    depth = 1
                    obj = ["{"]
                else:
                    depth += 1
                    obj.append("{")
                return None
            if collecting:
                obj.append(ch)
                # NOTE(review): '{' is already fully handled (and returned)
                # above, so only the '}' branch below can fire here.
                if ch == "{":
                    depth += 1
                elif ch == "}":
                    depth -= 1
                    if depth == 0:
                        # Object closed: try to parse it; drop it if invalid.
                        s = "".join(obj)
                        collecting = False
                        obj = []
                        try:
                            return json.loads(s)
                        except Exception:
                            return "__BAD__"
            return None
        def consume(text):
            # Sub-generator: yields parsed dicts; returns "__END__" (as the
            # generator's StopIteration value, captured via `yield from`)
            # when the array's terminator is reached.
            for ch in text:
                out = feed(ch)
                if out == "__END__":
                    return "__END__"
                if isinstance(out, dict):
                    yield out
        # Process the buffered remainder first, then stream the rest of the file.
        end = yield from consume(buf)
        if end == "__END__":
            return
        while True:
            chunk = f.read(1024 * 1024)
            if not chunk:
                break
            end = yield from consume(chunk)
            if end == "__END__":
                return
def merge_metadata(output_dir: str, merge_world_size: int, run_tag: str):
    """Merge per-rank metadata files into <output_dir>/metadata.json.

    Sources, in increasing priority (later overrides earlier per "index"):
      1. metadata_process_{r}.json           — older runs; possibly truncated,
         so complete sample objects are salvaged with iter_samples_salvage.
      2. metadata_process_{r}.{run_tag}.json — current tagged run; assumed
         complete and loaded with plain json.load.

    Run-level header fields are copied best-effort from rank 0's metadata
    (tagged file preferred). A pre-existing metadata.json is kept as
    metadata.json.bak, and the merged result is written atomically.

    Args:
        output_dir: directory holding the per-rank metadata files.
        merge_world_size: number of ranks r to scan (0..merge_world_size-1).
        run_tag: tag of the current run's metadata files; falsy to skip them.
    """
    outdir = output_dir
    by_idx = {}
    # Old, possibly-truncated per-rank files first (lowest priority).
    for r in range(merge_world_size):
        old_p = os.path.join(outdir, f"metadata_process_{r}.json")
        # iter_samples_salvage is a generator function, so it always returns
        # an iterable (possibly empty); the previous `or []` guard was dead
        # code and has been removed.
        for s in iter_samples_salvage(old_p):
            idx = s.get("index", None)
            if idx is None:
                continue
            by_idx[int(idx)] = s
    # New tagged files second: they override old entries with the same index.
    if run_tag:
        for r in range(merge_world_size):
            new_p = os.path.join(outdir, f"metadata_process_{r}.{run_tag}.json")
            if not os.path.exists(new_p):
                continue
            try:
                with open(new_p, "r") as f:
                    meta = json.load(f)
                for s in meta.get("samples", []):
                    idx = s.get("index", None)
                    if idx is None:
                        continue
                    by_idx[int(idx)] = s
            except Exception as e:
                logger.warning(f"Failed to load new meta {new_p}: {e}")
    samples = [by_idx[k] for k in sorted(by_idx.keys())]
    merged = {
        "num_extracted_metadata": len(samples),
        "world_size_used": merge_world_size,
        "samples": samples,
    }
    # Pull header info from any available rank-0 metadata (tagged preferred).
    header_src = None
    if run_tag:
        p0 = os.path.join(outdir, f"metadata_process_0.{run_tag}.json")
        if os.path.exists(p0):
            header_src = p0
    if header_src is None:
        p0_old = os.path.join(outdir, "metadata_process_0.json")
        if os.path.exists(p0_old):
            header_src = p0_old
    if header_src:
        # Best-effort: if the header source is truncated JSON, skip silently.
        try:
            with open(header_src, "r") as f:
                m0 = json.load(f)
            for k in [
                "extract_video","extract_text","text_encoder_architecture","video_tokenizer_model_id",
                "codebook_size","mask_token_id","num_frames","video_height","video_width",
                "prompt_prefix","text_dtype","save_attention_mask","empty_embeds_shape","empty_embeds_path",
                "num_samples_original","resume_from_index","num_samples_this_run","num_attempted",
                "num_extracted","num_failed","num_processes","ranks_seen"
            ]:
                if k in m0 and m0[k] is not None:
                    merged[k] = m0[k]
        except Exception:
            pass
    metadata_file = os.path.join(outdir, "metadata.json")
    # Back up any existing merged metadata before overwriting it.
    if os.path.exists(metadata_file):
        bak = os.path.join(outdir, "metadata.json.bak")
        try:
            os.replace(metadata_file, bak)
            logger.info(f"Backed up old metadata.json -> {bak}")
        except Exception:
            pass
    atomic_save_json(metadata_file, merged, indent=2)
    logger.info(f"[MERGE] Wrote {metadata_file}, samples={len(samples):,}")
# -----------------------------
# Sampler for index list
# -----------------------------
class IndexListSampler(Sampler):
    """Sampler that yields a fixed, caller-supplied sequence of dataset indices.

    Used for index_file-driven resume runs, where each rank processes an
    explicit list of global indices instead of a DistributedSampler shard.
    """

    def __init__(self, indices):
        # Copy defensively so later mutation of the caller's sequence
        # cannot change this sampler's iteration order.
        self.indices = list(indices)

    def __iter__(self):
        yield from self.indices

    def __len__(self):
        return len(self.indices)
# -----------------------------
# Args
# -----------------------------
def parse_args(argv=None):
    """Parse command-line arguments for feature extraction.

    Args:
        argv: optional list of argument strings. Defaults to None, in which
              case argparse reads sys.argv[1:] — backward-compatible with the
              original zero-argument call, and makes the CLI unit-testable.

    Returns:
        argparse.Namespace with all extraction options.
    """
    parser = argparse.ArgumentParser(description="Extract video codes and text embeddings")
    # Required I/O locations.
    parser.add_argument("--csv_path", type=str, required=True, help="Path to OpenVid1M CSV file")
    parser.add_argument("--video_root_dir", type=str, default=None, help="Root directory containing video files. If None, auto-detect.")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory to save extracted features")
    # Model selection.
    parser.add_argument(
        "--text_encoder_architecture",
        type=str,
        default="umt5-base",
        choices=["umt5-base", "umt5-xxl", "t5"],
        help="Text encoder architecture",
    )
    parser.add_argument(
        "--video_tokenizer_model_id",
        type=str,
        default="Cosmos-1.0-Tokenizer-DV8x16x16",
        help="HuggingFace model ID for Cosmos video tokenizer",
    )
    # Video geometry and loader settings.
    parser.add_argument("--num_frames", type=int, default=16, help="Number of frames per video")
    parser.add_argument("--video_height", type=int, default=480, help="Video height")
    parser.add_argument("--video_width", type=int, default=848, help="Video width")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size for feature extraction")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of dataloader workers")
    parser.add_argument("--max_samples", type=int, default=None, help="Max samples (for testing). If None, process all.")
    parser.add_argument("--resume_from_index", type=int, default=0, help="Resume extraction from this index")
    parser.add_argument("--prompt_prefix", type=str, default=None, help="Prefix to add to prompts")
    # Which features to extract; at least one must be set (checked in main()).
    parser.add_argument("--extract_video", action="store_true", default=False, help="Extract video codes")
    parser.add_argument("--extract_text", action="store_true", default=False, help="Extract text embeddings")
    parser.add_argument("--text_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"], help="Text encoder dtype")
    parser.add_argument("--skip_existing", action="store_true", help="Skip samples whose feature .npy already exist.")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing .npy (disables skip_existing).")
    # Paired on/off flags sharing one dest; default is on.
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--save_attention_mask", dest="save_attention_mask", action="store_true",
                       help="Save attention mask per sample (default: on).")
    group.add_argument("--no_save_attention_mask", dest="save_attention_mask", action="store_false",
                       help="Do NOT save attention mask per sample.")
    parser.set_defaults(save_attention_mask=True)
    # Resume / merge additions.
    parser.add_argument(
        "--index_file",
        type=str,
        default=None,
        help="Text file with one global sample index per line. "
             "Can contain '{rank}' placeholder, e.g. missing_process_{rank}.txt"
    )
    parser.add_argument(
        "--run_tag",
        type=str,
        default=None,
        help="Tag for new per-rank metadata file: metadata_process_{rank}.{run_tag}.json"
    )
    parser.add_argument(
        "--merge_world_size",
        type=int,
        default=None,
        help="How many ranks to merge for final metadata.json. "
             "Set to 8 for your OpenVid1M run even if you resume with 1 GPU."
    )
    return parser.parse_args(argv)
# -----------------------------
# Main
# -----------------------------
def main():
    """Distributed feature-extraction entry point.

    Each Accelerate process encodes its shard of the dataset (video codes via
    the Cosmos tokenizer and/or T5 text embeddings), writes one .npy per
    sample into a hierarchical layout, records per-rank metadata, and finally
    (main process only) merges all per-rank metadata into metadata.json.
    """
    args = parse_args()
    accelerator = Accelerator()
    rank = accelerator.process_index
    world_size = accelerator.num_processes
    logger.info(f"Process {rank}/{world_size} on device {accelerator.device}")
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
        logger.info(f"Output directory: {args.output_dir}")
        logger.info(f"Using {accelerator.num_processes} GPUs for parallel extraction")
        logger.info(f"Extract video codes: {args.extract_video}")
        logger.info(f"Extract text embeddings: {args.extract_text}")
        logger.info(f"skip_existing={args.skip_existing}, overwrite={args.overwrite}, save_attention_mask={args.save_attention_mask}")
        logger.info(f"index_file={args.index_file}, run_tag={args.run_tag}, merge_world_size={args.merge_world_size}")
    # Validate on every rank so all processes fail together.
    if not args.extract_video and not args.extract_text:
        raise ValueError("At least one feature type must be enabled. Use --extract_video and/or --extract_text.")
    device = accelerator.device
    dtype = torch.float32
    # ---- text encoder/tokenizer
    text_encoder = None
    tokenizer = None
    if args.extract_text:
        logger.info(f"Loading text encoder: {args.text_encoder_architecture} with dtype {args.text_dtype}")
        text_dtype_map = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
        text_dtype = text_dtype_map[args.text_dtype]
        if args.text_encoder_architecture == "umt5-base":
            model_id = "google/umt5-base"
        elif args.text_encoder_architecture == "umt5-xxl":
            model_id = "google/umt5-xxl"
        elif args.text_encoder_architecture == "t5":
            model_id = "t5-base"
        else:
            raise ValueError(f"Unknown text encoder: {args.text_encoder_architecture}")
        text_encoder = T5EncoderModel.from_pretrained(model_id, torch_dtype=text_dtype)
        tokenizer = T5Tokenizer.from_pretrained(model_id)
        text_encoder.to(device=device)
        text_encoder.eval()
        text_encoder.requires_grad_(False)
        # empty_embeds on main process: the null-prompt embedding used for
        # classifier-free-guidance style conditional dropout during training.
        if accelerator.is_main_process:
            logger.info("Extracting empty_embeds for conditional dropout...")
            with torch.no_grad():
                empty = tokenizer("", return_tensors="pt", padding="max_length", max_length=512, truncation=True)
                empty_ids = empty["input_ids"].to(device)
                empty_mask = empty["attention_mask"].to(device)
                outputs = text_encoder(input_ids=empty_ids, attention_mask=empty_mask)
                empty_embeds = outputs.last_hidden_state  # [1, 512, D]
            empty_embeds_cpu = empty_embeds.cpu()
            # numpy has no bfloat16; upcast before converting, then store fp16.
            if empty_embeds_cpu.dtype == torch.bfloat16:
                empty_embeds_cpu = empty_embeds_cpu.to(torch.float32)
            empty_embeds_np = empty_embeds_cpu.numpy().astype(np.float16)
            empty_embeds_path = os.path.join(args.output_dir, "empty_embeds.npy")
            np.save(empty_embeds_path, empty_embeds_np)
            logger.info(f"Saved empty_embeds to: {empty_embeds_path} shape={empty_embeds_np.shape} dtype={empty_embeds_np.dtype}")
    else:
        logger.info("Skipping text encoder loading (extract_text=False)")
        # Video-only runs still need a tokenizer for the dataset's prompt
        # pipeline; kept inside this else so it never clobbers the tokenizer
        # chosen above for --text_encoder_architecture.
        if args.extract_video:
            tokenizer = T5Tokenizer.from_pretrained("google/umt5-base")
    # ---- video tokenizer
    video_tokenizer = None
    if args.extract_video:
        logger.info(f"Loading video tokenizer: {args.video_tokenizer_model_id}")
        video_tokenizer = CosmosVideoTokenizer(model_id=args.video_tokenizer_model_id, device=device, dtype=dtype)
        video_tokenizer.requires_grad_(False)
        video_tokenizer.eval()
    else:
        logger.info("Skipping video tokenizer loading (extract_video=False)")
    # ---- auto-detect video_root_dir: prefer a "video_reorg" sibling of the
    # CSV (or of its parent), falling back to the CSV's own directory.
    if args.video_root_dir is None:
        csv_dir = os.path.dirname(args.csv_path)
        if os.path.exists(os.path.join(csv_dir, "video_reorg")):
            video_root_dir = os.path.join(csv_dir, "video_reorg")
        elif os.path.exists(os.path.join(os.path.dirname(csv_dir), "video_reorg")):
            video_root_dir = os.path.join(os.path.dirname(csv_dir), "video_reorg")
        else:
            video_root_dir = csv_dir
            logger.warning(f"Video directory not found, using CSV directory: {video_root_dir}")
    else:
        video_root_dir = args.video_root_dir
    # ---- dataset
    dataset = OpenVid1MDataset(
        csv_path=args.csv_path,
        video_root_dir=video_root_dir,
        tokenizer=tokenizer,
        num_frames=args.num_frames,
        height=args.video_height,
        width=args.video_width,
        text_encoder_architecture=args.text_encoder_architecture,
        prompt_prefix=args.prompt_prefix,
        use_random_temporal_crop=False,
        use_random_crop=False,
    )
    # NOTE(review): assumes dataset.data is a list-like of per-sample rows
    # supporting slicing and .get() — defined in train.dataset_utils, not
    # visible here.
    if args.max_samples is not None:
        dataset.data = dataset.data[:args.max_samples]
        logger.info(f"Limited dataset to {len(dataset)} samples")
    if args.resume_from_index > 0:
        # After this slice, dataset index 0 corresponds to global sample
        # index args.resume_from_index.
        dataset.data = dataset.data[args.resume_from_index:]
        logger.info(f"Resuming from index {args.resume_from_index}, remaining samples: {len(dataset)}")
    num_processes = accelerator.num_processes
    process_index = accelerator.process_index
    # ---- sampler: DistributedSampler OR IndexListSampler
    if args.index_file is not None:
        # Per-rank index file; '{rank}' placeholder expands to this process.
        idx_path = args.index_file.format(rank=process_index)
        with open(idx_path, "r") as f:
            wanted_sample_idx = [int(x.strip()) for x in f if x.strip() and not x.strip().startswith("#")]
        sampler_indices = []
        for sample_idx in wanted_sample_idx:
            # Convert global sample index -> position in the (possibly
            # resume-sliced) dataset; drop out-of-range entries.
            global_dataset_idx = sample_idx - args.resume_from_index
            if 0 <= global_dataset_idx < len(dataset.data):
                sampler_indices.append(global_dataset_idx)
        sampler = IndexListSampler(sampler_indices)
        logger.info(f"[GPU {process_index}] Using index_file={idx_path}, indices={len(sampler_indices)}")
    else:
        sampler = DistributedSampler(
            dataset,
            num_replicas=num_processes,
            rank=process_index,
            shuffle=False,
            drop_last=False,
        )
        # Materialize this rank's index order once; reused below to map
        # batch positions back to global sample indices.
        sampler_indices = list(sampler)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    # ---- output dirs (one tree per feature kind)
    video_codes_dir = None
    text_embeddings_dir = None
    attention_masks_dir = None
    if args.extract_video:
        video_codes_dir = os.path.join(args.output_dir, "video_codes")
        os.makedirs(video_codes_dir, exist_ok=True)
    if args.extract_text:
        text_embeddings_dir = os.path.join(args.output_dir, "text_embeddings")
        os.makedirs(text_embeddings_dir, exist_ok=True)
    # Created independently of extract_text because get_feature_paths builds
    # a mask path whenever save_attention_mask is on.
    if args.save_attention_mask:
        attention_masks_dir = os.path.join(args.output_dir, "attention_masks")
        os.makedirs(attention_masks_dir, exist_ok=True)
    # ---- load existing metadata shapes (optional) so skipped samples can
    # still be recorded with their original shapes.
    metadata_file = os.path.join(args.output_dir, "metadata.json")
    existing_shapes = {}
    if os.path.exists(metadata_file):
        try:
            with open(metadata_file, "r") as f:
                existing_meta = json.load(f)
            for sample in existing_meta.get("samples", []):
                idx = sample.get("index")
                if idx is None:
                    continue
                existing_shapes[int(idx)] = {
                    "video_code_shape": sample.get("video_code_shape"),
                    "text_embedding_shape": sample.get("text_embedding_shape"),
                    "context_len": sample.get("context_len"),
                }
            logger.info(f"[GPU {process_index}] Loaded existing metadata for {len(existing_shapes)} samples")
        except Exception as e:
            logger.warning(f"[GPU {process_index}] Failed to load existing metadata: {e}")
    total_samples = len(dataset)
    logger.info(f"[GPU {process_index}] Starting feature extraction for {total_samples} samples "
                f"(process {process_index+1}/{num_processes}), assigned={len(sampler_indices)}")
    # tokenizer info (recorded in metadata; attributes may be absent)
    codebook_size = None
    mask_token_id = None
    if args.extract_video and video_tokenizer is not None:
        codebook_size = getattr(video_tokenizer, "codebook_size", None)
        mask_token_id = getattr(video_tokenizer, "mask_token_id", None)
        logger.info(f"[GPU {process_index}] Video tokenizer: codebook_size={codebook_size}, mask_token_id={mask_token_id}")
    # empty embeds info (main process only; mmap reads just the header)
    empty_embeds_shape = None
    empty_embeds_path = os.path.join(args.output_dir, "empty_embeds.npy")
    if args.extract_text and accelerator.is_main_process and os.path.exists(empty_embeds_path):
        try:
            empty_embeds_np = np.load(empty_embeds_path, mmap_mode="r")
            empty_embeds_shape = list(empty_embeds_np.shape)
            logger.info(f"Empty embeds shape: {empty_embeds_shape}")
        except Exception:
            pass
    # Per-rank metadata skeleton; "samples" is appended to during the loop.
    process_metadata = {
        "process_index": process_index,
        "num_samples_this_run": total_samples,
        "world_size_used": world_size,
        "rank_used": rank,
        "extract_video": args.extract_video,
        "extract_text": args.extract_text,
        "text_encoder_architecture": args.text_encoder_architecture if args.extract_text else None,
        "video_tokenizer_model_id": args.video_tokenizer_model_id if args.extract_video else None,
        "codebook_size": codebook_size,
        "mask_token_id": mask_token_id,
        "num_frames": args.num_frames,
        "video_height": args.video_height,
        "video_width": args.video_width,
        "prompt_prefix": args.prompt_prefix,
        "text_dtype": args.text_dtype if args.extract_text else None,
        "save_attention_mask": args.save_attention_mask,
        "empty_embeds_shape": empty_embeds_shape if process_index == 0 else None,
        "empty_embeds_path": "empty_embeds.npy" if args.extract_text else None,
        "samples": [],
    }
    process_failed_samples = []
    process_samples_processed = 0
    process_attempted_samples = 0  # counts actual encoded+written attempts (post-skip)
    with torch.no_grad():
        for batch_idx, batch in enumerate(
            tqdm(dataloader, desc=f"[GPU {process_index}] Extracting", disable=not accelerator.is_main_process)
        ):
            batch_size = batch["video"].shape[0] if args.extract_video else batch["prompt_input_ids"].shape[0]
            local_start_idx = batch_idx * args.batch_size
            # ---- compute sample indices for this batch (global sample_idx);
            # the last batch may be short, so out-of-range slots become None.
            batch_sample_indices = []
            for i in range(batch_size):
                local_idx = local_start_idx + i
                if local_idx < len(sampler_indices):
                    global_dataset_idx = sampler_indices[local_idx]
                    sample_idx = args.resume_from_index + global_dataset_idx
                    batch_sample_indices.append(sample_idx)
                else:
                    batch_sample_indices.append(None)
            # ---- compute output paths
            batch_paths = []
            for sidx in batch_sample_indices:
                if sidx is None:
                    batch_paths.append(None)
                else:
                    batch_paths.append(get_feature_paths(args, video_codes_dir, text_embeddings_dir, attention_masks_dir, sidx))
            # ---- determine per-feature need flags (overwrite > skip_existing > default)
            need_video = [False] * batch_size
            need_text = [False] * batch_size
            need_mask = [False] * batch_size
            for i, (sidx, paths) in enumerate(zip(batch_sample_indices, batch_paths)):
                if sidx is None:
                    continue
                if args.overwrite:
                    need_video[i] = args.extract_video
                    need_text[i] = args.extract_text
                    need_mask[i] = args.extract_text and args.save_attention_mask
                elif args.skip_existing:
                    if args.extract_video:
                        need_video[i] = not os.path.exists(paths["video"])
                    if args.extract_text:
                        need_text[i] = not os.path.exists(paths["text"])
                        if args.save_attention_mask:
                            need_mask[i] = not os.path.exists(paths["mask"])
                else:
                    need_video[i] = args.extract_video
                    need_text[i] = args.extract_text
                    need_mask[i] = args.extract_text and args.save_attention_mask
            need_any = [v or t or m for v, t, m in zip(need_video, need_text, need_mask)]
            # IMPORTANT FIX:
            # - normal full run: if no one needs anything, skip this batch
            # - index_file resume: even if no extraction needed, still record metadata
            if (not any(need_any)) and (args.index_file is None):
                continue
            process_attempted_samples += sum(need_any)
            if batch_idx == 0:
                preview = [x for x in batch_sample_indices[:5] if x is not None]
                logger.info(f"[GPU {process_index}] First batch sample indices: {preview}")
                logger.info(f"[GPU {process_index}] Need any in first batch: {sum(need_any)}/{len(need_any)}")
            # ---- encode video for needed samples only (saves tokenizer work)
            video_codes = None
            need_video_idx = [i for i, ok in enumerate(need_video) if ok]
            # batch position -> row in the compacted encoded output
            map_video_pos = {i: p for p, i in enumerate(need_video_idx)}
            if args.extract_video and len(need_video_idx) > 0:
                videos = batch["video"].to(device, non_blocking=True)
                videos_sel = videos[need_video_idx]
                try:
                    vc_sel = video_tokenizer.encode(videos_sel)
                    video_codes = vc_sel.detach().cpu().numpy()
                except Exception as e:
                    logger.error(f"[GPU {process_index}] Failed to encode video batch {batch_idx}: {e}")
                    for i in need_video_idx:
                        sidx = batch_sample_indices[i]
                        if sidx is not None:
                            process_failed_samples.append({"index": sidx, "reason": "video_encoding_failed"})
                    # Skip the whole batch on video failure (text also dropped).
                    continue
            # ---- text context lens + encode needed
            encoder_hidden_states = None
            attention_masks = None
            context_lens_np_full = None
            if args.extract_text:
                prompt_input_ids = batch["prompt_input_ids"].to(device, non_blocking=True)
                # NOTE(review): defensive check — a tuple/list would already
                # have failed .to() above; kept for behavior parity.
                if isinstance(prompt_input_ids, (tuple, list)):
                    prompt_input_ids = prompt_input_ids[0]
                prompt_input_ids = normalize_input_ids(prompt_input_ids).long()
                pad_id = tokenizer.pad_token_id
                # Non-pad positions define the attention mask and context length.
                attention_mask_full = (prompt_input_ids != pad_id).long()
                context_lens_full = attention_mask_full.sum(dim=-1)
                context_lens_np_full = context_lens_full.detach().cpu().numpy().astype(np.int32)
                need_text_idx = [i for i, ok in enumerate(need_text) if ok]
                map_text_pos = {i: p for p, i in enumerate(need_text_idx)}
                if len(need_text_idx) > 0:
                    ids_sel = prompt_input_ids[need_text_idx]
                    mask_sel = attention_mask_full[need_text_idx]
                    try:
                        outputs = text_encoder(input_ids=ids_sel, attention_mask=mask_sel)
                        enc = outputs.last_hidden_state.detach().cpu()
                        # numpy has no bfloat16; upcast first, then store fp16.
                        if enc.dtype == torch.bfloat16:
                            enc = enc.to(torch.float32)
                        encoder_hidden_states = enc.numpy().astype(np.float16)
                        if args.save_attention_mask:
                            attention_masks = mask_sel.detach().cpu().numpy().astype(np.int32)
                    except Exception as e:
                        logger.error(f"[GPU {process_index}] Failed to encode text batch {batch_idx}: {e}")
                        for i in need_text_idx:
                            sidx = batch_sample_indices[i]
                            if sidx is not None:
                                process_failed_samples.append({"index": sidx, "reason": "text_encoding_failed"})
                        continue
                else:
                    map_text_pos = {}
            else:
                need_text_idx = []
                map_text_pos = {}
            # ---- save per sample + record metadata
            for i in range(batch_size):
                local_idx = local_start_idx + i
                if local_idx >= len(sampler_indices):
                    continue
                global_dataset_idx = sampler_indices[local_idx]
                sample_idx = args.resume_from_index + global_dataset_idx
                paths = batch_paths[i]
                row = dataset.data[global_dataset_idx] if global_dataset_idx < len(dataset.data) else None
                if row is None:
                    continue
                existing_info = existing_shapes.get(sample_idx, {})
                video_code_shape = None
                if args.extract_video:
                    if need_video[i]:
                        if video_codes is None:
                            process_failed_samples.append({"index": sample_idx, "reason": "video_codes_none"})
                            continue
                        pos = map_video_pos[i]
                        video_code = video_codes[pos].astype(np.int32)
                        try:
                            atomic_save_npy(paths["video"], video_code)
                            video_code_shape = list(video_code.shape)
                        except Exception as e:
                            process_failed_samples.append({"index": sample_idx, "reason": f"video_save_failed: {str(e)}"})
                            continue
                    else:
                        # Skipped sample: recover shape from old metadata or
                        # from the file header on disk.
                        video_code_shape = existing_info.get("video_code_shape") or safe_mmap_shape(paths["video"])
                text_embedding_shape = None
                if args.extract_text:
                    if need_text[i]:
                        if encoder_hidden_states is None:
                            process_failed_samples.append({"index": sample_idx, "reason": "text_embeddings_none"})
                            continue
                        pos = map_text_pos[i]
                        text_emb = encoder_hidden_states[pos]
                        try:
                            atomic_save_npy(paths["text"], text_emb)
                            text_embedding_shape = list(text_emb.shape)
                        except Exception as e:
                            process_failed_samples.append({"index": sample_idx, "reason": f"text_save_failed: {str(e)}"})
                            continue
                        # Mask save lives inside the need_text branch: `pos`
                        # and `attention_masks` are only valid for samples
                        # that were just text-encoded. NOTE(review): a sample
                        # with need_mask but not need_text never gets its
                        # mask rewritten — confirm this is intended.
                        if args.save_attention_mask and need_mask[i]:
                            if attention_masks is None:
                                process_failed_samples.append({"index": sample_idx, "reason": "attention_masks_none"})
                                continue
                            try:
                                atomic_save_npy(paths["mask"], attention_masks[pos])
                            except Exception as e:
                                process_failed_samples.append({"index": sample_idx, "reason": f"mask_save_failed: {str(e)}"})
                                continue
                    else:
                        text_embedding_shape = existing_info.get("text_embedding_shape") or safe_mmap_shape(paths["text"])
                sample_meta = {
                    "index": sample_idx,
                    "video_path": row.get("video", ""),
                    "caption": row.get("caption", ""),
                }
                if args.extract_video and video_code_shape is not None:
                    sample_meta["video_code_shape"] = video_code_shape
                if args.extract_text and text_embedding_shape is not None:
                    sample_meta["text_embedding_shape"] = text_embedding_shape
                if args.extract_text and context_lens_np_full is not None:
                    sample_meta["context_len"] = int(context_lens_np_full[i])
                process_metadata["samples"].append(sample_meta)
                process_samples_processed += 1
            # periodic per-process metadata save (atomic, tagged); fires only
            # when the running count lands exactly on a multiple of 1000.
            if process_samples_processed > 0 and (process_samples_processed % 1000 == 0):
                suffix = f".{args.run_tag}" if args.run_tag else ""
                process_metadata_file = os.path.join(args.output_dir, f"metadata_process_{process_index}{suffix}.json")
                process_metadata["num_extracted"] = process_samples_processed
                process_metadata["failed_samples"] = process_failed_samples
                atomic_save_json(process_metadata_file, process_metadata, indent=2)
                logger.info(f"[GPU {process_index}] Progress: {process_samples_processed} samples recorded -> {process_metadata_file}")
    accelerator.wait_for_everyone()
    # final per-process metadata save
    suffix = f".{args.run_tag}" if args.run_tag else ""
    process_metadata_file = os.path.join(args.output_dir, f"metadata_process_{process_index}{suffix}.json")
    process_metadata["num_attempted"] = int(process_attempted_samples)
    process_metadata["num_extracted"] = int(process_samples_processed)
    process_metadata["num_failed"] = int(len(process_failed_samples))
    process_metadata["failed_samples"] = process_failed_samples
    atomic_save_json(process_metadata_file, process_metadata, indent=2)
    logger.info(f"[GPU {process_index}] Done: attempted={process_attempted_samples}, extracted(meta)={process_samples_processed}, failed={len(process_failed_samples)}")
    accelerator.wait_for_everyone()
    # merge on main process, after every rank has written its file
    if accelerator.is_main_process:
        merge_world = args.merge_world_size if args.merge_world_size is not None else world_size
        logger.info(f"[MERGE] merging world_size={merge_world} (run_tag={args.run_tag})")
        merge_metadata(args.output_dir, merge_world_size=merge_world, run_tag=args.run_tag)
if __name__ == "__main__":
    main()