File size: 6,507 Bytes
c2925de | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
import argparse
import json
from pathlib import Path
def count_csv_rows(csv_path: Path) -> int:
    """Return the number of data rows in a CSV file (total lines minus one header).

    Streams the file in binary mode so nothing is held in memory; an empty
    file yields 0 rather than -1.
    """
    with csv_path.open("rb") as f:
        total_lines = sum(1 for _ in f)
    return max(0, total_lines - 1)
def iter_samples_from_pos(f, initial_buf: str, stop_on_trunc=True):
    """Stream-parse `{...}` objects following '"samples": [' via brace matching.

    `initial_buf` is text already read past the opening '['; further text is
    pulled from `f` in 1 MiB chunks. Yields each completely parsed object as a
    dict; objects that fail `json.loads` (truncated/corrupted) are skipped.
    Stops at the ']' that closes the samples array. If the file ends while an
    object is still open (truncated file), that trailing partial object is
    silently dropped. (`stop_on_trunc` is accepted for compatibility but
    currently unused.)
    """

    def text_stream():
        # The pre-read buffer first, then the rest of the file.
        yield initial_buf
        while True:
            chunk = f.read(1024 * 1024)
            if not chunk:
                return
            yield chunk

    inside_string = False
    escaped = False
    depth = 0
    pieces = None  # None => not currently inside an object

    for text in text_stream():
        for ch in text:
            if inside_string:
                if pieces is not None:
                    pieces.append(ch)
                if escaped:
                    escaped = False
                elif ch == "\\":
                    escaped = True
                elif ch == '"':
                    inside_string = False
                continue
            # --- outside any JSON string ---
            if ch == '"':
                inside_string = True
                if pieces is not None:
                    pieces.append(ch)
                continue
            if pieces is None and ch == "]":
                # Closing bracket of the samples array: we are done.
                return
            if ch == "{":
                if pieces is None:
                    pieces = ["{"]
                    depth = 1
                else:
                    depth += 1
                    pieces.append("{")
                continue
            if pieces is not None:
                pieces.append(ch)
                if ch == "}":
                    depth -= 1
                    if depth == 0:
                        raw = "".join(pieces)
                        pieces = None
                        try:
                            obj = json.loads(raw)
                        except json.JSONDecodeError:
                            # Only happens when the object itself is
                            # truncated/corrupted: skip it.
                            continue
                        yield obj
    # End of file: if an object was still open it was truncated — drop it.
def iter_samples_salvage(meta_path: Path):
    """Fault-tolerantly yield each sample dict from metadata_process_i.json's "samples" array.

    If the file is truncated, still yields every complete sample that was
    written before the truncation point. Yields nothing when no "samples"
    array can be located (file too corrupt / different format).
    """
    key = '"samples"'
    with meta_path.open("r", encoding="utf-8-sig") as f:
        buf = ""
        found = False
        # Scan forward until we are just past the '[' that opens the array.
        while True:
            chunk = f.read(1024 * 1024)
            if not chunk:
                break
            buf += chunk
            k = buf.find(key)
            if k != -1:
                b = buf.find("[", k)
                if b != -1:
                    buf = buf[b + 1:]  # start after '['
                    found = True
                    break
                # BUG FIX: key found but its '[' not read yet. Previously the
                # tail-trim below could discard the key in this state, so the
                # array was never found. Keep the buffer from the key onward
                # and read more.
                buf = buf[k:]
                continue
            # Key not found yet: bound memory, keeping a 4 MiB tail as overlap
            # so a key split across a chunk boundary is still detected.
            if len(buf) > 8 * 1024 * 1024:
                buf = buf[-4 * 1024 * 1024:]
        if not found:
            return  # no samples array found
        yield from iter_samples_from_pos(f, buf)
def main():
    """Compare the indices recorded in each rank's metadata against the CSV.

    For rank r (0 <= r < processes) the expected sample indices are
    r, r + world_size, r + 2*world_size, ... < N, where N is the number of
    CSV data rows. Missing indices are written to missing_process_<r>.txt
    per rank and the de-duplicated union to missing_all.txt.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--csv", required=True, help="OpenVid1M_reorganized.csv 路径")
    ap.add_argument("--root", required=True, help="extracted_features_* 根目录(含 metadata_process_*.json)")
    ap.add_argument("--world_size", type=int, default=8)
    ap.add_argument("--processes", type=int, default=8)
    args = ap.parse_args()

    csv_path = Path(args.csv)
    root = Path(args.root)
    # Explicit checks instead of `assert`: asserts vanish under `python -O`.
    if not csv_path.exists():
        raise SystemExit(f"CSV not found: {csv_path}")
    if not root.exists():
        raise SystemExit(f"root not found: {root}")

    N = count_csv_rows(csv_path)
    print(f"CSV rows (N, 0-based indices 0..N-1): {N}")
    print(f"world_size={args.world_size}, processes={args.processes}")
    print("-" * 80)

    all_missing = []
    for r in range(args.processes):
        meta = root / f"metadata_process_{r}.json"
        if not meta.exists():
            print(f"[rank {r}] metadata missing file: {meta}")
            continue

        seen = set()
        total_samples_parsed = 0
        bad_or_missing_index = 0
        for s in iter_samples_salvage(meta):
            total_samples_parsed += 1
            raw_idx = s.get("index", None)
            if raw_idx is None:
                bad_or_missing_index += 1
                continue
            try:
                seen.add(int(raw_idx))
            except (TypeError, ValueError):
                # index present but not convertible to int
                bad_or_missing_index += 1
                continue

        # Rank r should have produced indices r, r+world_size, r+2*world_size, ...
        exp_count = 0
        missing = []
        if r < N:
            for expected in range(r, N, args.world_size):
                exp_count += 1
                if expected not in seen:
                    missing.append(expected)

        out_txt = root / f"missing_process_{r}.txt"
        out_txt.write_text("\n".join(map(str, missing)) + ("\n" if missing else ""), encoding="utf-8")
        all_missing.extend(missing)

        # Diagnostics: parsed samples far below exp_count usually means the
        # metadata file was truncated / never fully written.
        coverage = (len(seen) / exp_count * 100.0) if exp_count > 0 else 0.0
        print(
            f"[rank {r}] expected={exp_count:,} parsed_samples={total_samples_parsed:,} "
            f"unique_index={len(seen):,} idx_bad={bad_or_missing_index:,} "
            f"missing={len(missing):,} coverage={coverage:.2f}% -> {out_txt.name}"
        )

    all_missing = sorted(set(all_missing))
    out_all = root / "missing_all.txt"
    out_all.write_text("\n".join(map(str, all_missing)) + ("\n" if all_missing else ""), encoding="utf-8")
    print("-" * 80)
    print(f"TOTAL missing unique indices vs CSV = {len(all_missing):,} -> {out_all}")


if __name__ == "__main__":
    main()