import argparse
import json
from pathlib import Path


def count_csv_rows(csv_path: Path) -> int:
    """Return the number of data rows in a CSV file (lines minus one header).

    Streams the file line-by-line in binary mode so nothing is held in memory.
    NOTE(review): this counts physical newlines — a CSV field containing an
    embedded newline would be over-counted; confirm the input never has such
    fields.
    """
    n = 0
    with csv_path.open("rb") as f:
        for _ in f:
            n += 1
    return max(0, n - 1)


def iter_samples_from_pos(f, initial_buf: str, stop_on_trunc=True):
    """Stream-extract each ``{ ... }`` object found after ``"samples": [``.

    Character-level brace matching (string- and escape-aware) lets this work
    even on a truncated file: every fully written object is yielded and a
    trailing incomplete object is silently discarded.

    Args:
        f: text-mode file handle positioned right after the text already
            consumed into ``initial_buf``.
        initial_buf: leftover text starting just past the array's ``[``.
        stop_on_trunc: accepted for backward compatibility; currently unused
            (truncation is always tolerated).

    Yields:
        dict: each successfully decoded object of the samples array.
    """
    in_string = False   # currently inside a JSON string literal
    escape = False      # previous char inside the string was a backslash
    depth = 0           # brace nesting depth of the object being collected
    collecting = False  # True while accumulating one object's characters
    obj_chars = []      # accumulated characters of the current object

    def feed(ch: str):
        """Advance the parser by one character.

        Returns ``None`` to continue, ``"__END__"`` when the array's closing
        ``]`` is seen outside any object, ``"__BAD_OBJECT__"`` for a complete
        but undecodable object, or the decoded dict.
        """
        nonlocal in_string, escape, depth, collecting, obj_chars
        if in_string:
            if collecting:
                obj_chars.append(ch)
            if escape:
                escape = False
            elif ch == "\\":
                escape = True
            elif ch == '"':
                in_string = False
            return None
        # Not inside a string from here on.
        if ch == '"':
            in_string = True
            if collecting:
                obj_chars.append(ch)
            return None
        # A ']' outside any object terminates the samples array.
        if not collecting and ch == "]":
            return "__END__"
        if ch == "{":
            if not collecting:
                collecting = True
                depth = 1
                obj_chars = ["{"]
            else:
                depth += 1
                obj_chars.append("{")
            return None
        if collecting:
            obj_chars.append(ch)
            # '{' is impossible here (handled above), so only '}' matters.
            if ch == "}":
                depth -= 1
                if depth == 0:
                    s = "".join(obj_chars)
                    collecting = False
                    obj_chars = []
                    try:
                        return json.loads(s)
                    except json.JSONDecodeError:
                        # Normally only happens when the object itself was
                        # truncated or corrupted on disk.
                        return "__BAD_OBJECT__"
        return None

    def consume(text: str):
        """Feed every char of *text*; yield decoded dicts, skip bad objects."""
        for ch in text:
            out = feed(ch)
            if out == "__END__":
                return "__END__"
            if isinstance(out, dict):
                yield out
            # "__BAD_OBJECT__": skip silently.
        return None

    if (yield from consume(initial_buf)) == "__END__":
        return
    while True:
        chunk = f.read(1024 * 1024)
        if not chunk:
            break
        if (yield from consume(chunk)) == "__END__":
            return
    # EOF while still collecting means the file was truncated; the partial
    # trailing object is intentionally discarded.


def iter_samples_salvage(meta_path: Path):
    """Fault-tolerantly iterate samples in a ``metadata_process_i.json`` file.

    Scans for the ``"samples"`` key and its ``[`` in 1 MiB chunks, then hands
    off to :func:`iter_samples_from_pos`. If the file is truncated, as many
    leading complete samples as possible are still yielded.
    """
    with meta_path.open("r", encoding="utf-8-sig") as f:
        buf = ""
        found = False
        # Locate the '[' that follows '"samples"'.
        while True:
            chunk = f.read(1024 * 1024)
            if not chunk:
                break
            buf += chunk
            k = buf.find('"samples"')
            if k != -1:
                b = buf.find("[", k)
                if b != -1:
                    buf = buf[b + 1:]  # keep only text after the '['
                    found = True
                    break
            # Cap buffer growth; keeping the tail preserves a marker that is
            # split across a chunk boundary.
            if len(buf) > 8 * 1024 * 1024:
                buf = buf[-4 * 1024 * 1024:]
        if not found:
            # No samples array found (file too damaged / different format).
            return
        yield from iter_samples_from_pos(f, buf)


def main():
    """Compare per-rank metadata against the CSV and report missing indices."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--csv", required=True, help="OpenVid1M_reorganized.csv 路径")
    ap.add_argument("--root", required=True, help="extracted_features_* 根目录(含 metadata_process_*.json)")
    ap.add_argument("--world_size", type=int, default=8)
    ap.add_argument("--processes", type=int, default=8)
    args = ap.parse_args()

    csv_path = Path(args.csv)
    root = Path(args.root)
    assert csv_path.exists(), f"CSV not found: {csv_path}"
    assert root.exists(), f"root not found: {root}"

    N = count_csv_rows(csv_path)
    print(f"CSV rows (N, 0-based indices 0..N-1): {N}")
    print(f"world_size={args.world_size}, processes={args.processes}")
    print("-" * 80)

    all_missing = []
    for r in range(args.processes):
        meta = root / f"metadata_process_{r}.json"
        if not meta.exists():
            print(f"[rank {r}] metadata missing file: {meta}")
            continue

        seen = set()
        total_samples_parsed = 0
        bad_or_missing_index = 0
        for s in iter_samples_salvage(meta):
            total_samples_parsed += 1
            idx = s.get("index", None)
            if idx is None:
                bad_or_missing_index += 1
                continue
            try:
                idx = int(idx)
            except (TypeError, ValueError):
                # Non-numeric / non-convertible index value.
                bad_or_missing_index += 1
                continue
            seen.add(idx)

        # Rank r should own indices r, r+world_size, r+2*world_size, ...
        exp_count = 0
        missing = []
        if r < N:
            for idx in range(r, N, args.world_size):
                exp_count += 1
                if idx not in seen:
                    missing.append(idx)

        out_txt = root / f"missing_process_{r}.txt"
        out_txt.write_text("\n".join(map(str, missing)) + ("\n" if missing else ""), encoding="utf-8")
        all_missing.extend(missing)

        # Diagnostics: parsed count far below exp_count usually means the
        # metadata file was truncated / never finished writing.
        coverage = (len(seen) / exp_count * 100.0) if exp_count > 0 else 0.0
        print(
            f"[rank {r}] expected={exp_count:,} parsed_samples={total_samples_parsed:,} "
            f"unique_index={len(seen):,} idx_bad={bad_or_missing_index:,} "
            f"missing={len(missing):,} coverage={coverage:.2f}% -> {out_txt.name}"
        )

    all_missing = sorted(set(all_missing))
    out_all = root / "missing_all.txt"
    out_all.write_text("\n".join(map(str, all_missing)) + ("\n" if all_missing else ""), encoding="utf-8")
    print("-" * 80)
    print(f"TOTAL missing unique indices vs CSV = {len(all_missing):,} -> {out_all}")


if __name__ == "__main__":
    main()