File size: 5,722 Bytes
4a0f6a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b6f412
4a0f6a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b6f412
4a0f6a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""E00: Dataset schema validation and distribution summary.

This is essentially a smoke test for the shared (common) module layer and the dataset directory.
It loads every dataset, validates schema, and writes a rollup report.
"""
from __future__ import annotations

import argparse
import json
import sys
from collections import Counter
from pathlib import Path

import numpy as np
import pandas as pd

# Locate the repository root: start from this file's resolved path and walk
# upward until a directory containing "src" is found, or until the filesystem
# root is reached (where a path equals its own parent).
REPO_ROOT = Path(__file__).resolve()
while REPO_ROOT != REPO_ROOT.parent and not (REPO_ROOT / "src").exists():
    REPO_ROOT = REPO_ROOT.parent
# Prepend the root and its "src" directory to sys.path (skipping paths that do
# not exist or are already listed). Because each is inserted at index 0 in
# turn, "src" ends up ahead of the root. This runs before the enhanced_replica
# imports that follow, presumably so they resolve when run as a script.
for _candidate in (REPO_ROOT, REPO_ROOT / "src"):
    _candidate_str = str(_candidate)
    if _candidate.exists() and _candidate_str not in sys.path:
        sys.path.insert(0, _candidate_str)

from enhanced_replica.cli_args import add_base_args
from enhanced_replica.data_utils import get_ds_meta, load_dataset_manifest, load_dataset_splits, SPLITS, validate_schema
from enhanced_replica.io_utils import create_run_context, ensure_dir, write_csv, write_json, write_run_manifest, write_run_report, write_yaml_minimal


def summarize_df(df: pd.DataFrame) -> dict:
    """Return the row count and ``length_char`` statistics for a split DataFrame.

    The function is used while reporting on datasets that may be malformed,
    so it never raises on bad length data: a missing ``length_char`` column,
    or a column with no usable numeric values, yields zeroed length
    statistics while still reporting the real row count. Values that cannot
    be parsed as numbers (and NaN/None) are ignored when computing the
    statistics.

    Args:
        df: One split of a dataset.

    Returns:
        A dict with keys ``rows``, ``mean_length``, ``min_length``,
        ``max_length`` and ``median_length`` (in that order).
    """
    zero_lengths = {"mean_length": 0.0, "min_length": 0, "max_length": 0, "median_length": 0.0}
    if df.empty:
        return {"rows": 0, **zero_lengths}

    row_count = int(len(df))
    if "length_char" not in df.columns:
        # Schema problems are reported by the validator; the summary must
        # not abort the run because of them.
        return {"rows": row_count, **zero_lengths}

    # Coerce instead of astype(float) so unparsable entries become NaN and
    # can be dropped rather than raising.
    lengths = pd.to_numeric(df["length_char"], errors="coerce").dropna()
    if lengths.empty:
        # int(nan) would raise, so handle the "no usable values" case here.
        return {"rows": row_count, **zero_lengths}

    return {
        "rows": row_count,
        "mean_length": float(lengths.mean()),
        "min_length": int(lengths.min()),
        "max_length": int(lengths.max()),
        "median_length": float(lengths.median()),
    }


def _count_column(df: pd.DataFrame, column: str) -> Counter:
    """Count the values of *column*; return an empty Counter if it is absent."""
    if column not in df.columns:
        return Counter()
    return Counter(df[column].tolist())


def run_e00(args: argparse.Namespace) -> dict:
    """Validate every requested dataset and write the E00 run artifacts.

    For each dataset ID (from ``args.dataset_ids`` when given, otherwise all
    IDs in the manifest, sorted) the splits are loaded, each split in
    ``SPLITS`` is schema-validated, and per-split / per-dataset summary rows
    are collected. Problems with a single dataset (unknown ID, load failure,
    schema errors, columns missing from the summary step) are recorded and
    logged; they do not abort the run.

    Files written under ``ctx.run_dir``: ``dataset_manifest.csv``,
    ``dataset_detail.csv`` and ``schema_report.json``. The run config is
    written to ``ctx.config_file``, and the run manifest and report are
    written through ``write_run_manifest`` / ``write_run_report``.

    Args:
        args: Parsed CLI arguments. Reads ``output_root``, ``run_name``,
            ``dataset_ids``, ``seed`` and ``smoke``.

    Returns:
        A dict with ``overall_ok`` (bool), ``datasets_checked`` (int) and
        ``error_count`` (int).
    """
    ctx = create_run_context(eid="E00", output_root=Path(args.output_root), run_name=args.run_name)
    ctx.log(f"E00 start | run_name={ctx.run_name}")

    manifest = load_dataset_manifest()
    if args.dataset_ids:
        ds_ids = [s.strip() for s in args.dataset_ids.split(",") if s.strip()]
    else:
        ds_ids = sorted(manifest.keys())
    ctx.log(f"Datasets to validate: {ds_ids}")

    summary_rows = []
    detail_rows = []
    schema_errors = []
    overall_ok = True

    for ds_id in ds_ids:
        # A mistyped --dataset_ids entry is reported like any other load
        # failure instead of aborting the whole run with a KeyError.
        if ds_id not in manifest:
            ctx.log(f"ERROR loading {ds_id}: dataset_id not found in manifest")
            schema_errors.append(
                {"dataset_id": ds_id, "split": "all", "error": "load_failed: dataset_id not found in manifest"}
            )
            overall_ok = False
            continue
        info = manifest[ds_id]

        # 1. Resolve metadata and load splits (exercises get_ds_meta and
        #    load_dataset_splits). Broad catch is deliberate: any failure for
        #    one dataset is recorded and the remaining datasets still run.
        try:
            ds_meta = get_ds_meta(manifest, ds_id)
            splits = load_dataset_splits(ds_meta)
        except Exception as e:
            ctx.log(f"ERROR loading {ds_id}: {e}")
            schema_errors.append({"dataset_id": ds_id, "split": "all", "error": f"load_failed: {e}"})
            overall_ok = False
            continue

        # 2. Validate schema per split
        ds_ok = True
        for sp in SPLITS:
            df = splits[sp]
            errs = validate_schema(df)
            if errs:
                ds_ok = False
                overall_ok = False
                for err in errs:
                    schema_errors.append({"dataset_id": ds_id, "split": sp, "error": err})
                    ctx.log(f"  SCHEMA ERROR {ds_id}/{sp}: {err}")

            # 3. Summarize. A split that failed validation may lack the
            #    columns read here, so degrade to empty/zero values rather
            #    than crash right after reporting the schema error.
            try:
                stats = summarize_df(df)
            except (KeyError, ValueError, TypeError) as e:
                ctx.log(f"  SUMMARY SKIPPED {ds_id}/{sp}: {e}")
                stats = {
                    "rows": int(len(df)),
                    "mean_length": 0.0,
                    "min_length": 0,
                    "max_length": 0,
                    "median_length": 0.0,
                }
            label_counts = _count_column(df, "label")
            source_counts = _count_column(df, "source")

            detail_rows.append({
                "dataset_id": ds_id,
                "split": sp,
                **stats,
                "label_0": int(label_counts.get(0, 0)),
                "label_1": int(label_counts.get(1, 0)),
                # Sort on the stringified key so a stray NaN/None among the
                # source names cannot raise TypeError; the order is unchanged
                # when all keys are strings.
                "sources": "; ".join(
                    f"{k}={v}" for k, v in sorted(source_counts.items(), key=lambda kv: str(kv[0]))
                ),
            })

        total_rows = sum(len(splits[sp]) for sp in SPLITS)
        if ds_ok:
            ctx.log(f"  {ds_id}: OK | total_rows={total_rows}")
        else:
            ctx.log(f"  {ds_id}: SCHEMA ERRORS FOUND")

        # Overall per-dataset summary
        summary_rows.append({
            "dataset_id": ds_id,
            "tier": info.get("tier", ""),
            "total_rows": total_rows,
            "train": len(splits["train"]),
            "dev": len(splits["dev"]),
            "test": len(splits["test"]),
            "status": "ok" if ds_ok else "schema_error",
        })

    # 4. Write outputs
    write_csv(ctx.run_dir / "dataset_manifest.csv", summary_rows)
    write_csv(ctx.run_dir / "dataset_detail.csv", detail_rows)

    schema_report = {
        "overall_ok": overall_ok,
        "datasets_checked": len(ds_ids),
        "errors": schema_errors,
    }
    write_json(ctx.run_dir / "schema_report.json", schema_report)

    # 5. Write run artifacts
    config = {"seed": args.seed, "smoke": args.smoke}
    if args.dataset_ids:
        config["dataset_ids"] = ds_ids
    write_yaml_minimal(ctx.config_file, config)

    result = {
        "overall_ok": overall_ok,
        "datasets_checked": len(ds_ids),
        "error_count": len(schema_errors),
    }
    status = "success" if overall_ok else "schema_error"
    write_run_manifest(ctx, status=status, payload=result)
    write_run_report(ctx, status=status, config=config, payload=result)
    ctx.log(f"E00 complete | status={status} | errors={len(schema_errors)}")
    return result


def main() -> int:
    """Parse command-line arguments and run the E00 validation.

    Returns:
        0 when ``run_e00`` completes without raising. Schema errors found in
        the datasets do not change the return value; they are reported
        through the run artifacts and log.

    Raises:
        Exception: Any error raised by ``run_e00`` is re-raised after a
            one-line summary has been printed to stderr.
    """
    parser = argparse.ArgumentParser(description="E00 Dataset validation and distribution summary")
    parser = add_base_args(parser)
    parser.add_argument("--dataset_ids", default=None, help="Comma-separated dataset IDs to validate (default: all datasets in manifest)")
    args = parser.parse_args()
    try:
        run_e00(args)
    except Exception as e:
        # Diagnostics go to stderr so piped/captured stdout stays clean; the
        # exception is re-raised so the traceback and failure status remain.
        print(f"ERROR: {e}", file=sys.stderr)
        raise
    return 0


# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())