File size: 6,318 Bytes
4a0f6a5
 
 
 
 
 
 
 
 
6b6f412
4a0f6a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b6f412
4a0f6a5
 
 
 
 
 
 
6b6f412
4a0f6a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
from __future__ import annotations

import json
import warnings
from pathlib import Path
from typing import Dict, List, Sequence

import pandas as pd

from .io_utils import read_json, resolve_repo_path


# Fallback schema column names. Used by get_required_fields(), load_split_df()
# and validate_schema() when no explicit field list is supplied, and as the
# output column order of normalize_minimal_df().
DEFAULT_REQUIRED_FIELDS = ["record_id", "text", "label", "source", "split", "length_char", "topic", "model_slug"]
# Split names iterated when resolving/loading per-split files; also the set of
# values validate_schema() accepts in the "split" column.
SPLITS = ["train", "dev", "test"]


def get_required_fields(manifest: dict | None = None) -> List[str]:
    """Return required fields, preferring manifest metadata if present."""
    if manifest is not None:
        meta = manifest.get("__meta__") or manifest.get("_meta")
        if isinstance(meta, dict) and "required_fields" in meta:
            return list(meta["required_fields"])
    return list(DEFAULT_REQUIRED_FIELDS)


def load_dataset_manifest(manifest_file: Path | None = None) -> dict:
    """Read the central dataset manifest JSON and return its contents.

    When *manifest_file* is ``None`` the path defaults to
    ``io_utils.DEFAULT_MANIFEST_FILE``. The path is passed through
    ``resolve_repo_path`` before being read with ``read_json``.
    """
    target = manifest_file
    if target is None:
        # Imported lazily, only when the default is actually needed.
        from .io_utils import DEFAULT_MANIFEST_FILE

        target = DEFAULT_MANIFEST_FILE
    resolved = resolve_repo_path(target)
    return read_json(resolved)


def get_ds_meta(manifest: dict, ds_id: str) -> dict:
    """Build the metadata dict for one dataset entry (e.g. ``'DS01'``).

    The result holds ``"dataset_id"``, ``"dataset_dir"`` (as a string) and one
    path per name in ``SPLITS``. Split file names come from the entry's
    ``"splits"`` mapping when it is a dict; any split not listed there falls
    back to ``"<split>.jsonl"``.

    Raises:
        KeyError: if *ds_id* is not a key of *manifest*.
    """
    if ds_id not in manifest:
        raise KeyError(f"{ds_id} not found in dataset manifest")

    entry = manifest[ds_id]
    root = resolve_repo_path(entry["dataset_dir"])

    declared = entry.get("splits")
    overrides = declared if isinstance(declared, dict) else {}

    meta = {
        "dataset_id": entry["dataset_id"],
        "dataset_dir": str(root),
    }
    meta.update({sp: root / overrides.get(sp, f"{sp}.jsonl") for sp in SPLITS})
    return meta


def load_jsonl(path: Path, warn_on_error: bool = True) -> List[dict]:
    """Read a JSON Lines file and return the parsed value of every valid line.

    Blank (whitespace-only) lines are ignored. Lines that fail to parse as
    JSON are skipped; their 1-based line numbers are collected and, when
    *warn_on_error* is true, reported in a single warning (at most the first
    ten line numbers are listed).

    Args:
        path: File to read. Decoded as UTF-8; a leading byte-order mark is
            tolerated.
        warn_on_error: Emit a warning when malformed lines were skipped.

    Returns:
        The parsed values in file order. Values are whatever ``json.loads``
        yields; no check is made that each one is a JSON object.
    """
    rows: List[dict] = []
    bad_lines: List[int] = []
    # "utf-8-sig" strips a UTF-8 BOM if present and otherwise behaves like
    # "utf-8". With plain "utf-8" the BOM stays glued to the first line, which
    # json.loads rejects, silently dropping the first record.
    with path.open("r", encoding="utf-8-sig") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError:
                bad_lines.append(lineno)
    if bad_lines and warn_on_error:
        shown = bad_lines[:10]
        suffix = "..." if len(bad_lines) > 10 else ""
        # stacklevel=2 attributes the warning to the caller of load_jsonl.
        warnings.warn(
            f"Skipped {len(bad_lines)} malformed line(s) in {path}: lines {shown}{suffix}",
            stacklevel=2,
        )
    return rows


def load_split_df(path: Path, required_fields: Sequence[str] | None = None) -> pd.DataFrame:
    """Load one split file (JSON Lines) into a DataFrame.

    If the resulting frame is empty, an empty frame whose columns are
    *required_fields* (or ``DEFAULT_REQUIRED_FIELDS`` when that is ``None``)
    is returned instead. A non-empty frame is returned as loaded; its columns
    are not checked against *required_fields* here.
    """
    frame = pd.DataFrame(load_jsonl(path))
    if required_fields is None:
        fallback_columns = DEFAULT_REQUIRED_FIELDS
    else:
        fallback_columns = list(required_fields)

    if not frame.empty:
        return frame
    return pd.DataFrame(columns=fallback_columns)


def load_dataset_splits(ds_meta: dict, required_fields: Sequence[str] | None = None) -> Dict[str, pd.DataFrame]:
    """Load every split named in ``SPLITS`` and return them keyed by split name.

    *ds_meta* must map each split name to a file path; each path is loaded
    with ``load_split_df`` using the same *required_fields*.
    """
    return {
        split_name: load_split_df(Path(ds_meta[split_name]), required_fields=required_fields)
        for split_name in SPLITS
    }


def validate_schema(df: pd.DataFrame, required_fields: Sequence[str] | None = None) -> List[str]:
    errors: List[str] = []
    fields = list(required_fields) if required_fields is not None else DEFAULT_REQUIRED_FIELDS
    missing = [k for k in fields if k not in df.columns]
    if missing:
        errors.append(f"missing_fields={missing}")
        return errors
    if not df["label"].isin([0, 1]).all():
        errors.append("label_out_of_range")
    if not df["split"].isin(SPLITS).all():
        errors.append("split_out_of_range")
    if not df.empty:
        try:
            numeric = pd.to_numeric(df["length_char"], errors="coerce")
            if numeric.isna().any() or ((numeric % 1) != 0).any():
                errors.append("length_char_not_int")
        except Exception:
            errors.append("length_char_not_int")
    if (df["text"].astype(str).str.len() == 0).any():
        errors.append("empty_text")
    if df["record_id"].duplicated().any():
        errors.append("duplicate_record_id")
    return errors


def normalize_minimal_df(
    df: pd.DataFrame,
    source: str = "unknown",
    split: str = "train",
    topic: str = "unknown",
    model_slug: str = "unknown",
    record_id_prefix: str = "auto",
) -> pd.DataFrame:
    """Expand a minimal DataFrame (at least ``text`` and ``label``) to the standard 8-field schema.

    This function is meant to be called only by dataset builder scripts during
    data preprocessing; it should **not** be used in experiment scripts. The
    8-field schema remains the project's core contract, and experiment scripts
    should always consume JSONL that has already been normalized.

    Columns already present in *df* are kept (``label`` is cast to int);
    missing ones are filled in: ``record_id`` from an MD5 digest of the text,
    ``length_char`` from the text length, the rest from the keyword arguments.
    The input frame is not modified.

    Raises:
        ValueError: if *df* lacks a ``text`` or ``label`` column.
    """
    import hashlib

    if any(needed not in df.columns for needed in ("text", "label")):
        raise ValueError("normalize_minimal_df requires at least 'text' and 'label' columns")

    result = df.copy()

    # Labels are stored as integers.
    result["label"] = result["label"].astype(int)

    # Derive a deterministic id from the text when none is supplied:
    # the first 16 hex characters of its MD5 digest.
    if "record_id" not in result.columns:
        result["record_id"] = result["text"].astype(str).apply(
            lambda value: hashlib.md5(str(value).encode("utf-8")).hexdigest()[:16]
        )
        # "auto" means no prefix; any other value is prepended with "_".
        if record_id_prefix != "auto":
            result["record_id"] = record_id_prefix + "_" + result["record_id"]

    # Character count of the stringified text, when not supplied.
    if "length_char" not in result.columns:
        result["length_char"] = result["text"].astype(str).str.len().astype(int)

    # Constant defaults for the remaining descriptive columns.
    fill_values = {
        "source": source,
        "split": split,
        "topic": topic,
        "model_slug": model_slug,
    }
    for column, value in fill_values.items():
        if column not in result.columns:
            result[column] = value

    # Make sure every schema column exists, then return them in schema order.
    for column in DEFAULT_REQUIRED_FIELDS:
        if column not in result.columns:
            result[column] = None
    return result[DEFAULT_REQUIRED_FIELDS].copy()


def merge_predictions(base: pd.DataFrame, pred_df: pd.DataFrame, score_col: str, pred_col: str) -> pd.DataFrame:
    """Left-join prediction columns onto the base records by ``record_id``.

    Only ``record_id``, ``split``, ``label``, ``length_char`` and ``source``
    are taken from *base*, and only ``record_id``, *score_col* and *pred_col*
    from *pred_df*. Every base row is kept; rows without a prediction get
    missing values. The join is validated as one-to-one, so pandas raises if
    ``record_id`` is duplicated on either side.
    """
    base_columns = ["record_id", "split", "label", "length_char", "source"]
    prediction_columns = [score_col, pred_col]

    left = base[base_columns]
    right = pred_df[["record_id", *prediction_columns]]
    joined = pd.merge(left, right, on="record_id", how="left", validate="one_to_one")
    return joined[base_columns + prediction_columns]