"""Dataset manifest access, JSONL split loading, and schema validation helpers."""

from __future__ import annotations

import hashlib
import json
import warnings
from pathlib import Path
from typing import Dict, List, Sequence

import pandas as pd

from .io_utils import read_json, resolve_repo_path

# Columns every normalized record is expected to carry (the "8-field schema").
DEFAULT_REQUIRED_FIELDS = [
    "record_id",
    "text",
    "label",
    "source",
    "split",
    "length_char",
    "topic",
    "model_slug",
]

# Split names, in the order they are loaded and reported.
SPLITS = ["train", "dev", "test"]


def get_required_fields(manifest: dict | None = None) -> List[str]:
    """Return required fields, preferring manifest metadata if present.

    Args:
        manifest: Optional manifest dict. Its ``"__meta__"`` (or, failing
            that, ``"_meta"``) entry is consulted for a ``"required_fields"``
            list.

    Returns:
        A new list: the manifest's ``required_fields`` when available,
        otherwise a copy of ``DEFAULT_REQUIRED_FIELDS``.
    """
    if manifest is not None:
        meta = manifest.get("__meta__") or manifest.get("_meta")
        if isinstance(meta, dict) and "required_fields" in meta:
            return list(meta["required_fields"])
    return list(DEFAULT_REQUIRED_FIELDS)


def load_dataset_manifest(manifest_file: Path | None = None) -> dict:
    """Load the central dataset_manifests.json.

    Args:
        manifest_file: Path of the manifest. When ``None``, the
            ``DEFAULT_MANIFEST_FILE`` constant from ``io_utils`` is used.

    Returns:
        Whatever ``read_json`` returns for the resolved path.
    """
    if manifest_file is None:
        # NOTE(review): kept as a function-scope import as in the original;
        # presumably deliberate (e.g. import-order reasons) - confirm before
        # hoisting it to module level.
        from .io_utils import DEFAULT_MANIFEST_FILE

        manifest_file = DEFAULT_MANIFEST_FILE
    return read_json(resolve_repo_path(manifest_file))


def get_ds_meta(manifest: dict, ds_id: str) -> dict:
    """Extract dataset metadata for a given ds_id (e.g. 'DS01').

    Args:
        manifest: Manifest dict keyed by dataset id.
        ds_id: Key to look up in ``manifest``.

    Returns:
        A dict with ``"dataset_id"``, ``"dataset_dir"`` (as ``str``) and one
        entry per name in ``SPLITS`` holding the split file path
        (``dataset_dir / filename``).

    Raises:
        KeyError: If ``ds_id`` is not in ``manifest``, or the entry lacks
            ``"dataset_dir"`` / ``"dataset_id"``.
    """
    if ds_id not in manifest:
        raise KeyError(f"{ds_id} not found in dataset manifest")

    info = manifest[ds_id]
    ds_dir = resolve_repo_path(info["dataset_dir"])
    out = {
        "dataset_id": info["dataset_id"],
        "dataset_dir": str(ds_dir),
    }

    # Prefer explicit split mapping from manifest; fall back to default jsonl names.
    splits = info.get("splits")
    split_map = splits if isinstance(splits, dict) else {}
    for sp in SPLITS:
        out[sp] = ds_dir / split_map.get(sp, f"{sp}.jsonl")
    return out


def load_jsonl(path: Path, warn_on_error: bool = True) -> List[dict]:
    """Read a JSON Lines file into a list of dict records.

    Blank lines are ignored. Lines that fail to parse as JSON, or that parse
    to something other than a JSON object, are skipped and counted as
    malformed so the returned list only ever contains dicts.

    Args:
        path: File to read (opened as UTF-8 text).
        warn_on_error: When true, emit one warning summarising the skipped
            lines (at most the first ten line numbers are listed).

    Returns:
        The parsed records, in file order.
    """
    rows: List[dict] = []
    bad_lines: List[int] = []

    with path.open("r", encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                bad_lines.append(lineno)
                continue
            if isinstance(record, dict):
                rows.append(record)
            else:
                # Valid JSON but not a record (number, string, list, ...).
                bad_lines.append(lineno)

    if bad_lines and warn_on_error:
        suffix = "..." if len(bad_lines) > 10 else ""
        warnings.warn(
            f"Skipped {len(bad_lines)} malformed line(s) in {path}: "
            f"lines {bad_lines[:10]}{suffix}",
            stacklevel=2,  # attribute the warning to the caller, not this helper
        )
    return rows


def load_split_df(path: Path, required_fields: Sequence[str] | None = None) -> pd.DataFrame:
    """Load one split file into a DataFrame.

    Args:
        path: JSONL file of the split.
        required_fields: Column names to use when the file yields no rows;
            defaults to ``DEFAULT_REQUIRED_FIELDS``.

    Returns:
        A DataFrame of the parsed records. If there are none, an empty
        DataFrame that still carries the expected columns.
    """
    df = pd.DataFrame(load_jsonl(path))
    if not df.empty:
        return df

    cols = list(required_fields) if required_fields is not None else list(DEFAULT_REQUIRED_FIELDS)
    return pd.DataFrame(columns=cols)


def load_dataset_splits(
    ds_meta: dict, required_fields: Sequence[str] | None = None
) -> Dict[str, pd.DataFrame]:
    """Load every split listed in ``SPLITS`` for one dataset.

    Args:
        ds_meta: Mapping with one path-like entry per split name (the shape
            produced by ``get_ds_meta``).
        required_fields: Forwarded to ``load_split_df``.

    Returns:
        Dict from split name to its DataFrame.
    """
    return {
        sp: load_split_df(Path(ds_meta[sp]), required_fields=required_fields)
        for sp in SPLITS
    }


def validate_schema(df: pd.DataFrame, required_fields: Sequence[str] | None = None) -> List[str]:
    """Check a DataFrame against the record schema and return error codes.

    If any required column is missing, only ``missing_fields=[...]`` is
    reported. Otherwise the value checks run for each well-known column that
    is actually present, so a custom ``required_fields`` that omits one of
    them does not raise ``KeyError``.

    Args:
        df: Records to validate.
        required_fields: Columns that must exist; defaults to
            ``DEFAULT_REQUIRED_FIELDS``.

    Returns:
        A list of error strings (empty when the frame is valid). Possible
        entries: ``missing_fields=[...]``, ``label_out_of_range``,
        ``split_out_of_range``, ``length_char_not_int``, ``empty_text``,
        ``duplicate_record_id``.
    """
    fields = list(required_fields) if required_fields is not None else list(DEFAULT_REQUIRED_FIELDS)
    missing = [k for k in fields if k not in df.columns]
    if missing:
        return [f"missing_fields={missing}"]

    errors: List[str] = []

    if "label" in df.columns and not df["label"].isin([0, 1]).all():
        errors.append("label_out_of_range")

    if "split" in df.columns and not df["split"].isin(SPLITS).all():
        errors.append("split_out_of_range")

    if "length_char" in df.columns and not df.empty:
        try:
            numeric = pd.to_numeric(df["length_char"], errors="coerce")
            # NaN => not numeric at all; non-zero remainder => fractional value.
            if numeric.isna().any() or ((numeric % 1) != 0).any():
                errors.append("length_char_not_int")
        except Exception:
            # Deliberately broad: a validator should report, not crash, on
            # values pandas cannot coerce.
            errors.append("length_char_not_int")

    if "text" in df.columns:
        text = df["text"]
        # astype(str) would turn None/NaN into "None"/"nan", hiding missing
        # text, so missing values are tested explicitly.
        if (text.isna() | (text.astype(str).str.len() == 0)).any():
            errors.append("empty_text")

    if "record_id" in df.columns and df["record_id"].duplicated().any():
        errors.append("duplicate_record_id")

    return errors


def normalize_minimal_df(
    df: pd.DataFrame,
    source: str = "unknown",
    split: str = "train",
    topic: str = "unknown",
    model_slug: str = "unknown",
    record_id_prefix: str = "auto",
) -> pd.DataFrame:
    """Lift a minimal DataFrame (at least ``text`` and ``label``) to the standard 8-field schema.

    This function is meant to be called only by dataset builder scripts during
    data preprocessing; it should **not** be used in experiment scripts. The
    8-field schema remains the project's core contract, and experiment scripts
    should always consume JSONL that has already been normalized.

    Args:
        df: Input frame; must contain ``text`` and ``label`` columns.
        source: Value for ``source`` when that column is absent.
        split: Value for ``split`` when that column is absent.
        topic: Value for ``topic`` when that column is absent.
        model_slug: Value for ``model_slug`` when that column is absent.
        record_id_prefix: When ``record_id`` is auto-generated and this is not
            ``"auto"``, ids become ``f"{record_id_prefix}_{hash}"``. Ignored
            if ``df`` already has a ``record_id`` column.

    Returns:
        A new DataFrame with exactly the ``DEFAULT_REQUIRED_FIELDS`` columns,
        in that order. The input frame is not modified.

    Raises:
        ValueError: If ``text`` or ``label`` is missing.
    """
    if "text" not in df.columns or "label" not in df.columns:
        raise ValueError("normalize_minimal_df requires at least 'text' and 'label' columns")

    out = df.copy()

    # Ensure label is integer 0/1
    out["label"] = out["label"].astype(int)

    # Auto-generate record_id if missing (md5 of text to keep it deterministic)
    if "record_id" not in out.columns:

        def _md5(text: str) -> str:
            return hashlib.md5(text.encode("utf-8")).hexdigest()[:16]

        out["record_id"] = out["text"].astype(str).apply(_md5)
        if record_id_prefix != "auto":
            out["record_id"] = record_id_prefix + "_" + out["record_id"]

    # Auto-compute length_char if missing
    if "length_char" not in out.columns:
        out["length_char"] = out["text"].astype(str).str.len().astype(int)

    # Fill defaults for remaining fields
    defaults = {
        "source": source,
        "split": split,
        "topic": topic,
        "model_slug": model_slug,
    }
    for col, value in defaults.items():
        if col not in out.columns:
            out[col] = value

    # Enforce column order. The None fill is a safety net so the selection
    # below cannot fail if DEFAULT_REQUIRED_FIELDS gains a field not handled above.
    for col in DEFAULT_REQUIRED_FIELDS:
        if col not in out.columns:
            out[col] = None
    return out[DEFAULT_REQUIRED_FIELDS].copy()


def merge_predictions(
    base: pd.DataFrame, pred_df: pd.DataFrame, score_col: str, pred_col: str
) -> pd.DataFrame:
    """Left-join per-record predictions onto the base records.

    Args:
        base: Frame with ``record_id``, ``split``, ``label``, ``length_char``
            and ``source`` columns.
        pred_df: Frame with ``record_id`` plus ``score_col`` and ``pred_col``.
        score_col: Name of the score column in ``pred_df``.
        pred_col: Name of the prediction column in ``pred_df``.

    Returns:
        One row per base record with the five base columns followed by
        ``score_col`` and ``pred_col`` (missing where no prediction matched).

    Raises:
        pandas.errors.MergeError: If ``record_id`` is not unique on both
            sides (enforced by ``validate="one_to_one"``).
    """
    base_cols = ["record_id", "split", "label", "length_char", "source"]
    merged = base[base_cols].merge(
        pred_df[["record_id", score_col, pred_col]],
        on="record_id",
        how="left",
        validate="one_to_one",
    )
    return merged[base_cols + [score_col, pred_col]]