LUCIFerace's picture
Add files using upload-large-folder tool
b09b696 verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""E03: Rule-based module evaluation.
This script evaluates the pattern-based rule engine on DS01 dev/test,
producing per-sample hit records and per-rule precision statistics.
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
import numpy as np
import pandas as pd
# Make the project importable when this script is run directly: climb from
# this file's location until a directory that contains ``src`` is reached
# (or the filesystem root, if no such directory exists).
REPO_ROOT = Path(__file__).resolve()
while True:
    _at_fs_root = REPO_ROOT == REPO_ROOT.parent
    if _at_fs_root or (REPO_ROOT / "src").exists():
        break
    REPO_ROOT = REPO_ROOT.parent

# Prepend both the discovered root and its ``src`` folder to ``sys.path``,
# skipping paths that do not exist or are already registered.
for _candidate in (REPO_ROOT, REPO_ROOT / "src"):
    _candidate_str = str(_candidate)
    if not _candidate.exists() or _candidate_str in sys.path:
        continue
    sys.path.insert(0, _candidate_str)
from enhanced_replica.cli_args import add_base_args, resolve_arg, setup_logging
from enhanced_replica.data_utils import get_ds_meta, load_dataset_manifest, load_dataset_splits, SPLITS
from enhanced_replica.io_utils import (
create_run_context,
prepare_primary_ds_for_train,
save_pred_files,
write_csv,
write_run_manifest,
write_run_report,
write_yaml_minimal,
)
from enhanced_replica.rules import _DEFAULT_ENGINE, PatternRule, RuleEngine
def evaluate_rules_on_split(df: pd.DataFrame, engine: RuleEngine) -> pd.DataFrame:
    """Apply the rule engine to every row of *df* and return flattened hit records.

    Each input row produces one output row per entry returned by
    ``engine.explain(text)``.  A row for which ``explain`` returns nothing
    still produces a single placeholder row (empty ``rid``, zero
    ``match_count``) so that every sample is represented in the output.

    Args:
        df: Samples to evaluate.  Must provide ``record_id``, ``text`` and
            ``label``; ``split`` and ``length_char`` are optional and default
            to ``""`` and ``len(text)`` respectively.
        engine: Object exposing ``score(text)`` (a 3-tuple whose first item is
            the rule score) and ``explain(text)`` (an iterable of dicts with
            ``rid``, ``name``, ``category``, ``count``, ``weight`` and
            ``matches`` keys).

    Returns:
        A DataFrame with a fixed column set.  The schema is guaranteed even
        when *df* is empty, so callers can always index columns such as
        ``rid`` without a ``KeyError``.
    """
    columns = [
        "record_id",
        "split",
        "label",
        "length_char",
        "rule_score",
        "rid",
        "rule_name",
        "category",
        "match_count",
        "weight",
        "matches",
    ]
    # Rule-specific fields used for samples on which no rule fired.
    no_hit_fields = {
        "rid": "",
        "rule_name": "",
        "category": "",
        "match_count": 0,
        "weight": 0.0,
        "matches": "",
    }

    records = []
    for _, row in df.iterrows():
        text = str(row["text"])
        label = int(row["label"])
        score, _, _ = engine.score(text)
        explain = engine.explain(text)

        # Fields shared by every record emitted for this sample.
        sample_fields = {
            "record_id": row["record_id"],
            "split": row.get("split", ""),
            "label": label,
            "length_char": row.get("length_char", len(text)),
            "rule_score": score,
        }

        if not explain:
            records.append({**sample_fields, **no_hit_fields})
            continue

        for item in explain:
            records.append({
                **sample_fields,
                "rid": item["rid"],
                "rule_name": item["name"],
                "category": item["category"],
                "match_count": item["count"],
                "weight": item["weight"],
                "matches": " | ".join(str(m) for m in item["matches"]),
            })

    # Passing ``columns`` pins the schema, which matters when ``records`` is empty.
    return pd.DataFrame(records, columns=columns)
def compute_rule_precision(hit_df: pd.DataFrame, rules: list[PatternRule]) -> list[dict]:
    """Compute per-rule precision, recall, and hit rates.

    Two precision flavours are reported for each enabled rule:

    * ``precision_match`` (legacy): weighted by ``match_count``, so a single
      sample with many matches can dominate.
    * ``precision_sample``: each ``record_id`` counts once, which guards
      against extreme long-text outliers skewing the metric.

    ``recall_ai`` and ``fpr_human`` use the number of distinct ``record_id``
    values per class in *hit_df* as denominators.  Label ``1`` is treated as
    the positive (AI) class and label ``0`` as the negative (human) class.

    Args:
        hit_df: Hit records with at least ``record_id``, ``label``, ``rid``
            and ``match_count`` columns.  An empty frame (with or without
            columns) is accepted and yields all-zero statistics.
        rules: Rule objects exposing ``enabled``, ``rid``, ``name``,
            ``category`` and ``weight``.  Disabled rules are skipped.

    Returns:
        One dict per enabled rule, in the order given by *rules*.
    """
    if hit_df.empty:
        # A frame built from zero records may have no columns at all;
        # normalise it so the column lookups below cannot raise KeyError.
        hit_df = pd.DataFrame(columns=["record_id", "label", "rid", "match_count"])

    # Class masks and class totals do not depend on the rule: compute once.
    is_ai = hit_df["label"] == 1
    is_human = hit_df["label"] == 0
    total_ai = int(hit_df.loc[is_ai, "record_id"].nunique())
    total_human = int(hit_df.loc[is_human, "record_id"].nunique())

    rows = []
    for r in rules:
        if not r.enabled:
            continue

        fired = hit_df["rid"] == r.rid
        ai_hits = hit_df[fired & is_ai]
        human_hits = hit_df[fired & is_human]

        total_hits = int(hit_df.loc[fired, "match_count"].sum())
        tp = int(ai_hits["match_count"].sum())
        fp = int(human_hits["match_count"].sum())
        # Precision at match level (legacy).
        precision_match = tp / (tp + fp) if (tp + fp) > 0 else 0.0

        # Unique samples that fired this rule, per class.
        n_ai_fired = int(ai_hits["record_id"].nunique())
        n_human_fired = int(human_hits["record_id"].nunique())
        n_fired = n_ai_fired + n_human_fired
        # Sample-level precision: each sample counts once regardless of hit count.
        precision_sample = n_ai_fired / n_fired if n_fired > 0 else 0.0

        recall_ai = n_ai_fired / total_ai if total_ai > 0 else 0.0
        fpr_human = n_human_fired / total_human if total_human > 0 else 0.0

        rows.append({
            "rid": r.rid,
            "rule_name": r.name,
            "category": r.category,
            "weight": r.weight,
            "total_match_hits": total_hits,
            "tp_matches": tp,
            "fp_matches": fp,
            "precision_match": round(precision_match, 4),
            "precision_sample": round(precision_sample, 4),
            "recall_ai": round(recall_ai, 4),
            "fpr_human": round(fpr_human, 4),
            "ai_samples_fired": n_ai_fired,
            "human_samples_fired": n_human_fired,
            "total_ai_samples": total_ai,
            "total_human_samples": total_human,
        })
    return rows
def run_e03(args: argparse.Namespace) -> dict:
    """Run the E03 rule-module evaluation end to end.

    Loads the ``DS01`` dataset described by the manifest, applies the default
    rule engine to every non-empty split in ``SPLITS``, and writes the
    prediction files, ``rule_hits.csv``, ``rule_precision.csv``, the run
    config, the run manifest and the run report into the run context.

    Args:
        args: Parsed CLI arguments.  Reads ``log_level``, ``output_root``,
            ``run_name``, ``smoke``, ``manifest_file`` and ``seed``.

    Returns:
        A dict with ``dev_coverage``, ``test_coverage`` (fraction of samples
        on which at least one rule fired; ``0.0`` when the split was skipped)
        and ``determinative_candidates`` (count of rules passing the
        precision/recall thresholds).

    Raises:
        ValueError: If every split is empty, so nothing can be evaluated.
    """
    setup_logging(args.log_level)
    logger = logging.getLogger("E03")
    ctx = create_run_context(eid="E03", output_root=Path(args.output_root), run_name=args.run_name)
    logger.info(f"E03 Rule module eval start | run_name={ctx.run_name} | smoke={args.smoke}")

    manifest = load_dataset_manifest(Path(args.manifest_file))
    ds_meta = get_ds_meta(manifest, "DS01")
    logger.info(f"Dataset: {ds_meta['dataset_id']} | dir={ds_meta['dataset_dir']}")
    splits = prepare_primary_ds_for_train(ds_meta, smoke=args.smoke, seed=args.seed)
    for sp in SPLITS:
        logger.info(f" {sp}: {len(splits[sp])} rows")

    engine = _DEFAULT_ENGINE
    active_rules = sum(1 for r in engine.rules if r.enabled)
    logger.info(f"Active rules: {active_rules}")

    all_hits = []
    pred_splits = {}
    split_metrics = {}
    for sp in SPLITS:
        if splits[sp].empty:
            logger.warning(f"{sp} is empty, skipping.")
            continue

        hit_df = evaluate_rules_on_split(splits[sp], engine)
        hit_df["split"] = sp
        all_hits.append(hit_df)

        # Build prediction split for downstream consumers (E04/E05/E07).
        pred_df = splits[sp][["record_id", "text", "label", "source", "split", "length_char"]].copy()
        # Coerce to ``str`` exactly as evaluate_rules_on_split does, so the
        # score stored here always matches ``rule_score`` in the hit records
        # (and non-string text values do not reach the engine unconverted).
        pred_df["score"] = pred_df["text"].apply(lambda t: engine.score(str(t))[0])
        pred_df["pred"] = (pred_df["score"] >= 0.5).astype(int)
        pred_splits[sp] = pred_df

        # Aggregate per-split quick stats: a sample is "covered" when at
        # least one rule fired on it (placeholder rows carry an empty rid).
        fired = hit_df[hit_df["rid"] != ""]
        n_fired = fired["record_id"].nunique()
        n_total = len(splits[sp])
        split_metrics[sp] = {
            "samples": n_total,
            "fired": n_fired,
            "coverage": round(n_fired / n_total, 4) if n_total else 0.0,
        }
        logger.info(f"{sp} coverage: {n_fired}/{n_total} ({split_metrics[sp]['coverage']:.2%})")

    if not all_hits:
        raise ValueError("No dev/test data available for rule evaluation.")

    combined_hits = pd.concat(all_hits, ignore_index=True)
    rule_precisions = compute_rule_precision(combined_hits, engine.rules)

    # Save outputs
    save_pred_files(ctx, pred_splits)
    combined_hits.to_csv(ctx.run_dir / "rule_hits.csv", index=False, encoding="utf-8-sig")
    write_csv(ctx.run_dir / "rule_precision.csv", rule_precisions)

    # Determine determinative candidates (high sample-level precision + reasonable coverage)
    determinative_candidates = [
        r for r in rule_precisions
        if r["precision_sample"] >= 0.85 and r["recall_ai"] >= 0.05
    ]
    logger.info(f"Determinative candidates (precision_sample>=0.85 & recall_ai>=0.05): {len(determinative_candidates)}")
    for c in determinative_candidates:
        logger.info(f" - {c['rid']}: precision_sample={c['precision_sample']}, recall_ai={c['recall_ai']}")

    config = {
        "seed": args.seed,
        "smoke": args.smoke,
        "active_rules": active_rules,
        "determinative_candidates": len(determinative_candidates),
    }
    write_yaml_minimal(ctx.config_file, config)

    result = {
        "dev_coverage": split_metrics.get("dev", {}).get("coverage", 0.0),
        "test_coverage": split_metrics.get("test", {}).get("coverage", 0.0),
        "determinative_candidates": len(determinative_candidates),
    }
    write_run_manifest(ctx, status="success", payload=result)
    write_run_report(ctx, status="success", config=config, payload=result)
    logger.info("E03 complete.")
    return result
def main() -> int:
    """CLI entry point: parse arguments and run the E03 evaluation.

    Returns:
        ``0`` when the run finishes without error.

    Raises:
        Exception: Any error raised by ``run_e03`` is logged with its
            traceback on the ``E03`` logger and then re-raised unchanged.
    """
    cli = add_base_args(argparse.ArgumentParser(description="E03 Rule module evaluation"))
    parsed = cli.parse_args()
    try:
        run_e03(parsed)
    except Exception as exc:
        logging.getLogger("E03").error(f"ERROR: {exc}", exc_info=True)
        raise
    return 0
# Script entry guard: run the CLI and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())