enhanced-replica-model-pack / scripts /train_eval /data_checks /inspect_dataset_distribution.py
LUCIFerace's picture
Add files using upload-large-folder tool
6b6f412 verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""E00: Dataset schema validation and distribution summary.
This is essentially a smoke test for the shared common-module layer
(the ``enhanced_replica`` helpers) and the dataset directory.
It loads every dataset (or the subset named by --dataset_ids), validates its
schema, and writes a rollup report.
"""
from __future__ import annotations
import argparse
import json
import sys
from collections import Counter
from pathlib import Path
import numpy as np
import pandas as pd
# Locate the repository root: the nearest directory (starting from this file and
# walking upward) that contains a ``src/`` directory. If none does, the walk
# ends at the filesystem root, exactly where the search stops.
_this_file = Path(__file__).resolve()
REPO_ROOT = next(
    candidate
    for candidate in (_this_file, *_this_file.parents)
    if candidate == candidate.parent or (candidate / "src").exists()
)

# Make both the repo root and its ``src/`` importable (front of sys.path),
# skipping entries that do not exist or are already present.
for _path_entry in (REPO_ROOT, REPO_ROOT / "src"):
    if not _path_entry.exists():
        continue
    if str(_path_entry) in sys.path:
        continue
    sys.path.insert(0, str(_path_entry))
from enhanced_replica.cli_args import add_base_args
from enhanced_replica.data_utils import get_ds_meta, load_dataset_manifest, load_dataset_splits, SPLITS, validate_schema
from enhanced_replica.io_utils import create_run_context, ensure_dir, write_csv, write_json, write_run_manifest, write_run_report, write_yaml_minimal
def summarize_df(df: pd.DataFrame) -> dict:
    """Return the row count and character-length statistics for one split.

    Args:
        df: A split DataFrame. Length statistics are read from its
            ``length_char`` column.

    Returns:
        A dict with ``rows`` (the frame's row count, 0 for an empty frame)
        and ``mean_length``, ``min_length``, ``max_length`` and
        ``median_length``. The length statistics are computed over the rows
        whose ``length_char`` is numeric. They are zero when the frame is
        empty, lacks the column, or has no numeric lengths at all.
    """
    rows = 0 if df.empty else int(len(df))
    zero_stats = {"rows": rows, "mean_length": 0.0, "min_length": 0, "max_length": 0, "median_length": 0.0}
    if df.empty or "length_char" not in df.columns:
        return zero_stats
    # Coerce rather than astype(float). astype raises on non-numeric text, and an
    # all-missing column yields a NaN min/max that int() rejects. Either would
    # abort the whole validation report instead of summarizing what is usable.
    lengths = pd.to_numeric(df["length_char"], errors="coerce").astype(float).dropna()
    if lengths.empty:
        return zero_stats
    return {
        "rows": rows,
        "mean_length": float(lengths.mean()),
        "min_length": int(lengths.min()),
        "max_length": int(lengths.max()),
        "median_length": float(lengths.median()),
    }
def _parse_dataset_ids(raw, manifest) -> list[str]:
    """Return the comma-separated IDs in *raw*, or every manifest ID (sorted) when *raw* is empty.

    Blank entries are dropped, and repeated IDs are kept once in first-seen
    order, so a dataset is never validated (and reported) twice.
    """
    if not raw:
        return sorted(manifest.keys())
    stripped = (part.strip() for part in raw.split(","))
    return list(dict.fromkeys(part for part in stripped if part))


def _value_counts(df: pd.DataFrame, column: str) -> Counter:
    """Count the values in *column*; empty Counter when the frame is empty or lacks the column."""
    if df.empty or column not in df.columns:
        return Counter()
    return Counter(df[column].tolist())


def _summarize_split(df: pd.DataFrame) -> tuple[dict, str | None]:
    """Run ``summarize_df`` on one split without letting malformed data abort the run.

    Returns ``(stats, None)`` on success. If summarizing raised, returns
    ``(fallback_stats, error_message)``; the fallback has the same keys so
    the detail rows stay uniform.
    """
    try:
        return summarize_df(df), None
    except (KeyError, ValueError, TypeError, OverflowError) as exc:
        fallback = {
            "rows": int(len(df)),
            "mean_length": 0.0,
            "min_length": 0,
            "max_length": 0,
            "median_length": 0.0,
        }
        return fallback, f"summary_failed: {type(exc).__name__}: {exc}"


def run_e00(args: argparse.Namespace) -> dict:
    """Validate the requested datasets and write the E00 distribution reports.

    For each dataset ID, this loads all splits, runs ``validate_schema`` on
    each, and records per-split length, label and source statistics. IDs come
    from ``args.dataset_ids`` or, when that is unset, every ID in the dataset
    manifest. Problems are collected rather than raised, so one bad dataset
    (or an unknown ID) does not hide the others.

    Writes ``dataset_manifest.csv``, ``dataset_detail.csv`` and
    ``schema_report.json`` into the run directory, plus the run config,
    run manifest and run report.

    Args:
        args: Parsed CLI arguments. Reads ``output_root``, ``run_name``,
            ``dataset_ids``, ``seed`` and ``smoke``.

    Returns:
        A dict with ``overall_ok``, ``datasets_checked`` and ``error_count``.
    """
    ctx = create_run_context(eid="E00", output_root=Path(args.output_root), run_name=args.run_name)
    ctx.log(f"E00 start | run_name={ctx.run_name}")
    manifest = load_dataset_manifest()
    ds_ids = _parse_dataset_ids(args.dataset_ids, manifest)
    ctx.log(f"Datasets to validate: {ds_ids}")

    summary_rows = []
    detail_rows = []
    schema_errors = []
    overall_ok = True

    for ds_id in ds_ids:
        # An ID typed on the command line may not exist. Report it instead of
        # crashing with a KeyError before any output has been written.
        if ds_id not in manifest:
            ctx.log(f"ERROR unknown dataset id {ds_id}")
            schema_errors.append({"dataset_id": ds_id, "split": "all", "error": "unknown_dataset_id"})
            overall_ok = False
            continue
        info = manifest[ds_id]

        # 1. Load splits (exercises enhanced_replica.data_utils.load_dataset_splits).
        #    The broad catch is deliberate: any load failure is recorded and the run moves on.
        try:
            ds_meta = get_ds_meta(manifest, ds_id)
            splits = load_dataset_splits(ds_meta)
        except Exception as e:
            ctx.log(f"ERROR loading {ds_id}: {e}")
            schema_errors.append({"dataset_id": ds_id, "split": "all", "error": f"load_failed: {e}"})
            overall_ok = False
            continue

        # 2. Validate the schema and 3. summarize each split.
        ds_ok = True
        for sp in SPLITS:
            df = splits[sp]
            errs = list(validate_schema(df) or [])
            # Summarize even when the schema is broken, but record missing or
            # malformed columns as errors instead of letting them crash the run.
            stats, summary_err = _summarize_split(df)
            if summary_err is not None:
                errs.append(summary_err)
            if errs:
                ds_ok = False
                overall_ok = False
                for err in errs:
                    schema_errors.append({"dataset_id": ds_id, "split": sp, "error": err})
                    ctx.log(f" SCHEMA ERROR {ds_id}/{sp}: {err}")
            label_counts = _value_counts(df, "label")
            source_counts = _value_counts(df, "source")
            detail_rows.append({
                "dataset_id": ds_id,
                "split": sp,
                **stats,
                "label_0": int(label_counts.get(0, 0)),
                "label_1": int(label_counts.get(1, 0)),
                "sources": "; ".join(f"{k}={v}" for k, v in sorted(source_counts.items())),
            })

        # Compute the split sizes once; the total and the per-split columns both
        # follow SPLITS, so they always agree.
        split_sizes = {sp: len(splits[sp]) for sp in SPLITS}
        total_rows = sum(split_sizes.values())
        if ds_ok:
            ctx.log(f" {ds_id}: OK | total_rows={total_rows}")
        else:
            ctx.log(f" {ds_id}: SCHEMA ERRORS FOUND")

        # Overall per-dataset summary. NOTE(review): the per-split columns are named
        # and ordered by SPLITS (previously hard-coded train/dev/test).
        summary_rows.append({
            "dataset_id": ds_id,
            "tier": info.get("tier", ""),
            "total_rows": total_rows,
            **split_sizes,
            "status": "ok" if ds_ok else "schema_error",
        })

    # 4. Write outputs.
    write_csv(ctx.run_dir / "dataset_manifest.csv", summary_rows)
    write_csv(ctx.run_dir / "dataset_detail.csv", detail_rows)
    schema_report = {
        "overall_ok": overall_ok,
        "datasets_checked": len(ds_ids),
        "errors": schema_errors,
    }
    write_json(ctx.run_dir / "schema_report.json", schema_report)

    # 5. Write run artifacts.
    config = {"seed": args.seed, "smoke": args.smoke}
    if args.dataset_ids:
        config["dataset_ids"] = ds_ids
    write_yaml_minimal(ctx.config_file, config)
    result = {
        "overall_ok": overall_ok,
        "datasets_checked": len(ds_ids),
        "error_count": len(schema_errors),
    }
    status = "success" if overall_ok else "schema_error"
    write_run_manifest(ctx, status=status, payload=result)
    write_run_report(ctx, status=status, config=config, payload=result)
    ctx.log(f"E00 complete | status={status} | errors={len(schema_errors)}")
    return result
def main() -> int:
    """CLI entry point: parse the arguments and run the E00 dataset validation.

    Returns:
        0 once ``run_e00`` completes. NOTE(review): this is 0 even when
        schema errors were found; that outcome is only recorded in the files
        ``run_e00`` writes (schema report and run manifest/report).

    Raises:
        Any exception from ``run_e00``. It is re-raised after a one-line
        message is printed to stderr.
    """
    parser = argparse.ArgumentParser(description="E00 Dataset validation and distribution summary")
    parser = add_base_args(parser)
    parser.add_argument("--dataset_ids", default=None, help="Comma-separated dataset IDs to validate (default: all datasets in manifest)")
    args = parser.parse_args()
    try:
        run_e00(args)
    except Exception as e:
        # Diagnostics go to stderr so they do not pollute captured stdout.
        print(f"ERROR: {e}", file=sys.stderr)
        raise
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())