LUCIFerace's picture
Add files using upload-large-folder tool
6b6f412 verified
from __future__ import annotations
import json
import warnings
from pathlib import Path
from typing import Dict, List, Sequence
import pandas as pd
from .io_utils import read_json, resolve_repo_path
# Columns every normalized record must carry (the project's 8-field schema).
# Used as the fallback by get_required_fields() and as the column order
# produced by normalize_minimal_df().
DEFAULT_REQUIRED_FIELDS = ["record_id", "text", "label", "source", "split", "length_char", "topic", "model_slug"]
# Split names, in the order they are resolved/loaded; also the set of allowed
# values for the "split" column checked by validate_schema().
SPLITS = ["train", "dev", "test"]
def get_required_fields(manifest: dict | None = None) -> List[str]:
    """Return the list of required record fields.

    When ``manifest`` holds a metadata dict under ``"__meta__"`` (or under
    ``"_meta"`` if ``"__meta__"`` is absent or falsy) that contains a
    ``"required_fields"`` entry, a new list built from that entry is returned.
    Otherwise a fresh copy of ``DEFAULT_REQUIRED_FIELDS`` is returned.
    """
    meta = None if manifest is None else (manifest.get("__meta__") or manifest.get("_meta"))
    has_override = isinstance(meta, dict) and "required_fields" in meta
    return list(meta["required_fields"]) if has_override else list(DEFAULT_REQUIRED_FIELDS)
def load_dataset_manifest(manifest_file: Path | None = None) -> dict:
    """Load the central dataset_manifests.json.

    Args:
        manifest_file: Manifest location; when omitted, ``DEFAULT_MANIFEST_FILE``
            from ``io_utils`` is used. The path is passed through
            ``resolve_repo_path`` before being read with ``read_json``.

    Returns:
        Whatever ``read_json`` returns for the manifest (expected to be a dict).
    """
    if manifest_file is not None:
        target = manifest_file
    else:
        # Imported lazily, only when the default is actually needed.
        from .io_utils import DEFAULT_MANIFEST_FILE
        target = DEFAULT_MANIFEST_FILE
    return read_json(resolve_repo_path(target))
def get_ds_meta(manifest: dict, ds_id: str) -> dict:
    """Resolve metadata and split file paths for dataset ``ds_id`` (e.g. 'DS01').

    Returns a dict with ``dataset_id``, ``dataset_dir`` (as a string) and one
    entry per name in ``SPLITS`` holding ``<resolved dataset_dir> / <filename>``.
    The filename comes from the entry's ``"splits"`` mapping when that value is
    a dict, and defaults to ``"<split>.jsonl"`` otherwise.

    Raises:
        KeyError: if ``ds_id`` is not a key of ``manifest``.
    """
    if ds_id not in manifest:
        raise KeyError(f"{ds_id} not found in dataset manifest")
    entry = manifest[ds_id]
    # presumably resolve_repo_path returns a Path, since "/" is applied to it below
    root = resolve_repo_path(entry["dataset_dir"])
    result = {"dataset_id": entry["dataset_id"], "dataset_dir": str(root)}
    declared = entry.get("splits")
    split_files = declared if isinstance(declared, dict) else {}
    result.update((name, root / split_files.get(name, f"{name}.jsonl")) for name in SPLITS)
    return result
def load_jsonl(path: Path, warn_on_error: bool = True) -> List[dict]:
    """Read a JSON Lines file into a list of parsed records.

    Whitespace-only lines are ignored. Lines that fail to parse as JSON are
    skipped. Their 1-based line numbers are collected and, when
    ``warn_on_error`` is true, reported in a single warning that lists at most
    the first 10 of them.

    Args:
        path: Path of the ``.jsonl`` file to read.
        warn_on_error: Emit a warning summarising skipped malformed lines.

    Returns:
        The parsed JSON value of each non-blank, well-formed line, in file
        order. Each value is appended as parsed; non-object JSON values are
        not filtered out.

    Raises:
        OSError: if the file cannot be opened (e.g. ``FileNotFoundError``).
    """
    rows: List[dict] = []
    bad_lines: List[int] = []
    # "utf-8-sig" drops a leading byte-order mark. With plain "utf-8" the BOM
    # stays on line 1, json.loads rejects it ("Unexpected UTF-8 BOM") and the
    # first record is silently skipped. Files without a BOM decode identically.
    with path.open("r", encoding="utf-8-sig") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError:
                bad_lines.append(lineno)
    if bad_lines and warn_on_error:
        shown = bad_lines[:10]
        more = "..." if len(bad_lines) > 10 else ""
        # stacklevel=2 attributes the warning to the caller, not this helper.
        warnings.warn(
            f"Skipped {len(bad_lines)} malformed line(s) in {path}: lines {shown}{more}",
            stacklevel=2,
        )
    return rows
def load_split_df(path: Path, required_fields: Sequence[str] | None = None) -> pd.DataFrame:
    """Load one JSONL split file into a DataFrame.

    Rows from ``load_jsonl(path)`` are used as-is, with no column filtering or
    reordering. If the split yields an empty frame, a zero-row DataFrame is
    returned whose columns are ``required_fields`` (or
    ``DEFAULT_REQUIRED_FIELDS`` when not given), so the schema columns are
    still present.
    """
    frame = pd.DataFrame(load_jsonl(path))
    schema_cols = DEFAULT_REQUIRED_FIELDS if required_fields is None else list(required_fields)
    return pd.DataFrame(columns=schema_cols) if frame.empty else frame
def load_dataset_splits(ds_meta: dict, required_fields: Sequence[str] | None = None) -> Dict[str, pd.DataFrame]:
    """Load every split listed in ``SPLITS`` for one dataset.

    Args:
        ds_meta: Mapping with one path-like entry per split name, e.g. the
            result of ``get_ds_meta``.
        required_fields: Forwarded to ``load_split_df`` for each split.

    Returns:
        Dict from split name to its DataFrame, in ``SPLITS`` order.
    """
    return {
        split_name: load_split_df(Path(ds_meta[split_name]), required_fields=required_fields)
        for split_name in SPLITS
    }
def validate_schema(df: pd.DataFrame, required_fields: Sequence[str] | None = None) -> List[str]:
    """Check a split DataFrame against the record schema.

    Args:
        df: Records to validate, one row per record.
        required_fields: Column names that must be present. Defaults to
            ``DEFAULT_REQUIRED_FIELDS``.

    Returns:
        A list of error codes, empty when every check passes.

        If any required column is missing, only ``"missing_fields=[...]"`` is
        returned and the value checks are skipped. Otherwise the possible codes
        are:

        - ``label_out_of_range``: a label is not 0/1.
        - ``split_out_of_range``: a split is not in ``SPLITS``.
        - ``length_char_not_int``: a length is non-numeric, null or non-integral.
        - ``empty_text``: a text is empty or null.
        - ``duplicate_record_id``: a record id appears more than once.

        Each value check runs only when its column exists. A custom
        ``required_fields`` list (e.g. from manifest metadata) that omits a
        checked column is therefore validated rather than raising ``KeyError``.
    """
    errors: List[str] = []
    fields = list(required_fields) if required_fields is not None else DEFAULT_REQUIRED_FIELDS
    missing = [k for k in fields if k not in df.columns]
    if missing:
        errors.append(f"missing_fields={missing}")
        return errors
    present = df.columns
    if "label" in present and not df["label"].isin([0, 1]).all():
        errors.append("label_out_of_range")
    if "split" in present and not df["split"].isin(SPLITS).all():
        errors.append("split_out_of_range")
    if "length_char" in present and not df.empty:
        try:
            numeric = pd.to_numeric(df["length_char"], errors="coerce")
            not_int = bool(numeric.isna().any() or ((numeric % 1) != 0).any())
        except Exception:
            # Deliberately broad: any failure to interpret the column as
            # numbers is reported as a validation error, not raised.
            not_int = True
        if not_int:
            errors.append("length_char_not_int")
    if "text" in present:
        text = df["text"]
        # astype(str) maps None/NaN to "None"/"nan", which would hide missing
        # text, so null values are counted as empty explicitly.
        if (text.isna() | (text.astype(str).str.len() == 0)).any():
            errors.append("empty_text")
    if "record_id" in present and df["record_id"].duplicated().any():
        errors.append("duplicate_record_id")
    return errors
def normalize_minimal_df(
    df: pd.DataFrame,
    source: str = "unknown",
    split: str = "train",
    topic: str = "unknown",
    model_slug: str = "unknown",
    record_id_prefix: str = "auto",
) -> pd.DataFrame:
    """Expand a minimal DataFrame (at least ``text`` and ``label``) to the standard 8-field schema.

    This function is meant only for dataset builder scripts during data
    preprocessing and should **not** be used in experiment scripts. The
    8-field schema remains the project's core contract, and experiment scripts
    should always consume JSONL that has already been normalized.

    Columns already present in ``df`` are kept as they are. Missing ones are
    filled as follows:

    - ``record_id``: first 16 hex chars of the MD5 of ``text``, optionally
      prefixed.
    - ``length_char``: character length of ``text``.
    - ``source`` / ``split`` / ``topic`` / ``model_slug``: the matching
      argument.

    Extra columns are dropped, and the result is ordered as
    ``DEFAULT_REQUIRED_FIELDS``.

    Args:
        df: Input records. It is copied, not modified.
        source, split, topic, model_slug: Constant values used when the
            corresponding column is absent.
        record_id_prefix: When not ``"auto"``, auto-generated record ids become
            ``f"{record_id_prefix}_{md5}"``. Existing ids are left untouched.

    Returns:
        A new DataFrame with exactly the ``DEFAULT_REQUIRED_FIELDS`` columns.

    Raises:
        ValueError: if ``text`` or ``label`` is missing, or if any label is not
            0 or 1 after integer conversion.
    """
    import hashlib

    if "text" not in df.columns or "label" not in df.columns:
        raise ValueError("normalize_minimal_df requires at least 'text' and 'label' columns")
    out = df.copy()

    # Cast labels to int, then enforce the binary 0/1 contract. The cast alone
    # would let values such as 2 or -1 through into the normalized schema.
    out["label"] = out["label"].astype(int)
    if not out["label"].isin([0, 1]).all():
        raise ValueError("normalize_minimal_df requires every label to be 0 or 1")

    # Auto-generate record_id if missing. Hashing the text keeps ids
    # deterministic across runs; as a consequence, identical texts share an id.
    if "record_id" not in out.columns:
        def _md5(x: str) -> str:
            return hashlib.md5(str(x).encode("utf-8")).hexdigest()[:16]

        out["record_id"] = out["text"].astype(str).apply(_md5)
        if record_id_prefix != "auto":
            out["record_id"] = record_id_prefix + "_" + out["record_id"]

    # Auto-compute length_char if missing.
    if "length_char" not in out.columns:
        out["length_char"] = out["text"].astype(str).str.len().astype(int)

    # Fill constant defaults for the remaining descriptive fields.
    defaults = {"source": source, "split": split, "topic": topic, "model_slug": model_slug}
    for col, value in defaults.items():
        if col not in out.columns:
            out[col] = value

    # Safety net: any other schema column still absent is added as None, then
    # columns are selected in schema order.
    for col in DEFAULT_REQUIRED_FIELDS:
        if col not in out.columns:
            out[col] = None
    return out[DEFAULT_REQUIRED_FIELDS].copy()
def merge_predictions(base: pd.DataFrame, pred_df: pd.DataFrame, score_col: str, pred_col: str) -> pd.DataFrame:
    """Left-join one model's score/prediction columns onto base records by ``record_id``.

    Every row of ``base`` is kept in its original order. Records without a
    prediction get NaN in ``score_col`` / ``pred_col``.

    Returns:
        DataFrame with columns ``record_id, split, label, length_char, source``
        followed by ``score_col`` and ``pred_col``.

    Raises:
        pandas.errors.MergeError: if ``record_id`` is not unique in ``base`` or
            ``pred_df`` (``validate="one_to_one"``).
    """
    base_cols = ["record_id", "split", "label", "length_char", "source"]
    left = base[base_cols]
    right = pred_df[["record_id", score_col, pred_col]]
    joined = left.merge(right, on="record_id", how="left", validate="one_to_one")
    return joined[base_cols + [score_col, pred_col]]