#!/usr/bin/env python3 # -*- coding: utf-8 -*- """E00: Dataset schema validation and distribution summary. This is essentially a smoke test for the 公共模块 layer and the dataset directory. It loads every dataset, validates schema, and writes a rollup report. """ from __future__ import annotations import argparse import json import sys from collections import Counter from pathlib import Path import numpy as np import pandas as pd REPO_ROOT = Path(__file__).resolve() while REPO_ROOT != REPO_ROOT.parent and not (REPO_ROOT / "src").exists(): REPO_ROOT = REPO_ROOT.parent for _candidate in (REPO_ROOT, REPO_ROOT / "src"): _candidate_str = str(_candidate) if _candidate.exists() and _candidate_str not in sys.path: sys.path.insert(0, _candidate_str) from enhanced_replica.cli_args import add_base_args from enhanced_replica.data_utils import get_ds_meta, load_dataset_manifest, load_dataset_splits, SPLITS, validate_schema from enhanced_replica.io_utils import create_run_context, ensure_dir, write_csv, write_json, write_run_manifest, write_run_report, write_yaml_minimal def summarize_df(df: pd.DataFrame) -> dict: """Return basic statistics for a split DataFrame.""" if df.empty: return {"rows": 0, "mean_length": 0.0, "min_length": 0, "max_length": 0, "median_length": 0.0} lengths = df["length_char"].astype(float) return { "rows": int(len(df)), "mean_length": float(lengths.mean()), "min_length": int(lengths.min()), "max_length": int(lengths.max()), "median_length": float(lengths.median()), } def run_e00(args: argparse.Namespace) -> dict: ctx = create_run_context(eid="E00", output_root=Path(args.output_root), run_name=args.run_name) ctx.log(f"E00 start | run_name={ctx.run_name}") manifest = load_dataset_manifest() if args.dataset_ids: ds_ids = [s.strip() for s in args.dataset_ids.split(",") if s.strip()] else: ds_ids = sorted(manifest.keys()) ctx.log(f"Datasets to validate: {ds_ids}") summary_rows = [] detail_rows = [] schema_errors = [] overall_ok = True for ds_id in ds_ids: info = manifest[ds_id] ds_meta = get_ds_meta(manifest, ds_id) # 1. Load splits (this tests _common.data_utils.load_dataset_splits) try: splits = load_dataset_splits(ds_meta) except Exception as e: ctx.log(f"ERROR loading {ds_id}: {e}") schema_errors.append({"dataset_id": ds_id, "split": "all", "error": f"load_failed: {e}"}) overall_ok = False continue # 2. Validate schema per split ds_ok = True for sp in SPLITS: df = splits[sp] errs = validate_schema(df) if errs: ds_ok = False overall_ok = False for err in errs: schema_errors.append({"dataset_id": ds_id, "split": sp, "error": err}) ctx.log(f" SCHEMA ERROR {ds_id}/{sp}: {err}") # 3. Summarize stats = summarize_df(df) label_counts = Counter(df["label"].tolist()) if not df.empty else Counter() source_counts = Counter(df["source"].tolist()) if not df.empty else Counter() detail_rows.append({ "dataset_id": ds_id, "split": sp, **stats, "label_0": int(label_counts.get(0, 0)), "label_1": int(label_counts.get(1, 0)), "sources": "; ".join([f"{k}={v}" for k, v in sorted(source_counts.items())]), }) if ds_ok: ctx.log(f" {ds_id}: OK | total_rows={sum(len(splits[sp]) for sp in SPLITS)}") else: ctx.log(f" {ds_id}: SCHEMA ERRORS FOUND") # Overall per-dataset summary total_rows = sum(len(splits[sp]) for sp in SPLITS) summary_rows.append({ "dataset_id": ds_id, "tier": info.get("tier", ""), "total_rows": total_rows, "train": len(splits["train"]), "dev": len(splits["dev"]), "test": len(splits["test"]), "status": "ok" if ds_ok else "schema_error", }) # 4. Write outputs write_csv(ctx.run_dir / "dataset_manifest.csv", summary_rows) write_csv(ctx.run_dir / "dataset_detail.csv", detail_rows) schema_report = { "overall_ok": overall_ok, "datasets_checked": len(ds_ids), "errors": schema_errors, } write_json(ctx.run_dir / "schema_report.json", schema_report) # 5. Write run artifacts config = {"seed": args.seed, "smoke": args.smoke} if args.dataset_ids: config["dataset_ids"] = ds_ids write_yaml_minimal(ctx.config_file, config) result = { "overall_ok": overall_ok, "datasets_checked": len(ds_ids), "error_count": len(schema_errors), } status = "success" if overall_ok else "schema_error" write_run_manifest(ctx, status=status, payload=result) write_run_report(ctx, status=status, config=config, payload=result) ctx.log(f"E00 complete | status={status} | errors={len(schema_errors)}") return result def main() -> int: parser = argparse.ArgumentParser(description="E00 Dataset validation and distribution summary") parser = add_base_args(parser) parser.add_argument("--dataset_ids", default=None, help="Comma-separated dataset IDs to validate (default: all datasets in manifest)") args = parser.parse_args() try: run_e00(args) return 0 except Exception as e: print(f"ERROR: {e}") raise if __name__ == "__main__": raise SystemExit(main())