# 43.oT_eV / Meissonic / train / extract_check_missing.py
# Uploaded by BryanW (commit c2925de, verified) from /mnt/43.oT_eV
import argparse
import json
from pathlib import Path
def count_csv_rows(csv_path: Path) -> int:
    """Return the number of data rows in a CSV file.

    Counts physical lines and subtracts one for the header, streaming the
    file in binary mode so nothing is loaded into memory.  An empty file
    yields 0 (never a negative count).
    """
    with csv_path.open("rb") as fh:
        line_total = sum(1 for _ in fh)
    return max(0, line_total - 1)
def iter_samples_from_pos(f, initial_buf: str, stop_on_trunc=True):
    """
    Stream-extract each `{ ... }` object found after '"samples": [' by
    brace-matching, yielding every completely parsed object as a dict.

    Even if the file is truncated mid-array, all objects that were fully
    written are still yielded; a trailing incomplete object is dropped.

    Args:
        f: open text-mode file, positioned after the data already captured
           in ``initial_buf``.
        initial_buf: leftover text read past the '[' of the samples array;
            scanned before any further reads from ``f``.
        stop_on_trunc: NOTE(review): currently unused — a truncated trailing
            object is always dropped regardless of this flag.
    """
    # Character-level parser state (shared with the ``feed`` closure below).
    in_string = False   # currently inside a JSON string literal
    escape = False      # previous char was a backslash inside a string
    depth = 0           # brace-nesting depth of the object being collected
    collecting = False  # True while accumulating one top-level object
    obj_chars = []      # characters of the object currently being collected

    def feed(ch: str):
        # Advance the state machine by one character.
        # Returns None to keep going, "__END__" when the samples array's
        # closing ']' is seen, "__BAD_OBJECT__" when a collected object
        # fails json.loads, or the parsed dict for a completed object.
        nonlocal in_string, escape, depth, collecting, obj_chars
        if in_string:
            # Inside a string: braces/brackets have no structural meaning,
            # so only track escapes and the closing quote.
            if collecting:
                obj_chars.append(ch)
            if escape:
                escape = False
            else:
                if ch == "\\":
                    escape = True
                elif ch == '"':
                    in_string = False
            return None
        # not in string
        if ch == '"':
            in_string = True
            if collecting:
                obj_chars.append(ch)
            return None
        # samples array end
        if not collecting and ch == "]":
            return "__END__"
        if ch == "{":
            if not collecting:
                # Start of a new top-level sample object.
                collecting = True
                depth = 1
                obj_chars = ["{"]
            else:
                # Nested object inside the current sample.
                depth += 1
                obj_chars.append("{")
            return None
        if collecting:
            obj_chars.append(ch)
            if ch == "{":
                # NOTE(review): unreachable — '{' is fully handled (and
                # returned on) above; kept as-is.
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    # Object closed: try to parse the accumulated text.
                    s = "".join(obj_chars)
                    collecting = False
                    obj_chars = []
                    try:
                        return json.loads(s)
                    except json.JSONDecodeError:
                        # Normally only happens when the object itself is
                        # truncated or corrupted.
                        return "__BAD_OBJECT__"
        return None

    def consume(text: str):
        # Feed a chunk character-by-character, yielding parsed sample dicts.
        # The generator's *return value* is "__END__" once the array closes,
        # which the `yield from` below picks up.
        for ch in text:
            out = feed(ch)
            if out == "__END__":
                return "__END__"
            if isinstance(out, dict):
                yield out
            # bad object: skip
        return None

    # First drain the text already buffered by the caller, then stream the
    # rest of the file in 1 MiB chunks.
    endflag = yield from consume(initial_buf)
    if endflag == "__END__":
        return
    while True:
        chunk = f.read(1024 * 1024)
        if not chunk:
            break
        endflag = yield from consume(chunk)
        if endflag == "__END__":
            return
    # End of file: if we are still collecting, the file was truncated;
    # the final incomplete object is dropped.
    return
def iter_samples_salvage(meta_path: Path):
    """
    Fault-tolerantly iterate over the samples stored in a
    ``metadata_process_i.json`` file.

    Locates the '[' that opens the ``"samples"`` array, then delegates to
    :func:`iter_samples_from_pos` for streaming brace-matched extraction.
    If the file is truncated, the complete leading portion is still yielded;
    if no samples array can be found at all, nothing is yielded.
    """
    with meta_path.open("r", encoding="utf-8-sig") as fh:
        window = ""
        located = False
        # Scan forward in 1 MiB chunks until '"samples"' and its '[' appear.
        while not located:
            piece = fh.read(1024 * 1024)
            if not piece:
                break
            window += piece
            key_pos = window.find('"samples"')
            if key_pos != -1:
                bracket_pos = window.find("[", key_pos)
                if bracket_pos != -1:
                    window = window[bracket_pos + 1 :]  # resume just past '['
                    located = True
                    continue
            # Bound the scan window so it cannot grow without limit.
            if len(window) > 8 * 1024 * 1024:
                window = window[-4 * 1024 * 1024 :]
        if not located:
            return  # no samples array found (file too damaged / different format)
        yield from iter_samples_from_pos(fh, window)
def main():
    """
    Check which CSV row indices are missing from each rank's metadata file.

    Rank ``r`` is expected to have produced indices ``r, r + world_size,
    r + 2 * world_size, ...`` up to the CSV row count.  For each rank the
    missing indices are written to ``missing_process_r.txt`` under --root,
    and the union of all missing indices to ``missing_all.txt``.

    Raises:
        FileNotFoundError: if --csv or --root does not exist.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--csv", required=True, help="OpenVid1M_reorganized.csv 路径")
    ap.add_argument("--root", required=True, help="extracted_features_* 根目录(含 metadata_process_*.json)")
    ap.add_argument("--world_size", type=int, default=8)
    ap.add_argument("--processes", type=int, default=8)
    args = ap.parse_args()

    csv_path = Path(args.csv)
    root = Path(args.root)
    # Explicit checks instead of `assert`: asserts are stripped under
    # `python -O`, which would silently skip input validation.
    if not csv_path.exists():
        raise FileNotFoundError(f"CSV not found: {csv_path}")
    if not root.exists():
        raise FileNotFoundError(f"root not found: {root}")

    N = count_csv_rows(csv_path)
    print(f"CSV rows (N, 0-based indices 0..N-1): {N}")
    print(f"world_size={args.world_size}, processes={args.processes}")
    print("-" * 80)

    all_missing = []
    for r in range(args.processes):
        meta = root / f"metadata_process_{r}.json"
        if not meta.exists():
            print(f"[rank {r}] metadata missing file: {meta}")
            continue

        # Collect the set of indices this rank actually recorded, tolerating
        # truncated metadata and malformed "index" fields.
        seen = set()
        total_samples_parsed = 0
        bad_or_missing_index = 0
        for s in iter_samples_salvage(meta):
            total_samples_parsed += 1
            idx = s.get("index", None)
            if idx is None:
                bad_or_missing_index += 1
                continue
            try:
                idx = int(idx)
            except (TypeError, ValueError):
                # Narrowed from bare `except Exception`: only conversion
                # failures count as a bad index.
                bad_or_missing_index += 1
                continue
            seen.add(idx)

        # Expected indices for this rank: r, r+world_size, r+2*world_size, ...
        exp_count = 0
        missing = []
        if r < N:
            for idx in range(r, N, args.world_size):
                exp_count += 1
                if idx not in seen:
                    missing.append(idx)

        out_txt = root / f"missing_process_{r}.txt"
        out_txt.write_text("\n".join(map(str, missing)) + ("\n" if missing else ""), encoding="utf-8")
        all_missing.extend(missing)

        # Diagnostics: parsed count far below exp_count usually means the
        # metadata file was truncated / never finished writing.
        coverage = (len(seen) / exp_count * 100.0) if exp_count > 0 else 0.0
        print(
            f"[rank {r}] expected={exp_count:,} parsed_samples={total_samples_parsed:,} "
            f"unique_index={len(seen):,} idx_bad={bad_or_missing_index:,} "
            f"missing={len(missing):,} coverage={coverage:.2f}% -> {out_txt.name}"
        )

    all_missing = sorted(set(all_missing))
    out_all = root / "missing_all.txt"
    out_all.write_text("\n".join(map(str, all_missing)) + ("\n" if all_missing else ""), encoding="utf-8")
    print("-" * 80)
    print(f"TOTAL missing unique indices vs CSV = {len(all_missing):,} -> {out_all}")
if __name__ == "__main__":
main()