#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""E03: Rule-based module evaluation.

This script evaluates the pattern-based rule engine on DS01 dev/test,
producing per-sample hit records and per-rule precision statistics.
"""

from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path

import numpy as np
import pandas as pd

# Allow imports from source root package: walk upwards from this file until a
# directory containing "src" is found (or the filesystem root is reached).
REPO_ROOT = Path(__file__).resolve()
while REPO_ROOT != REPO_ROOT.parent and not (REPO_ROOT / "src").exists():
    REPO_ROOT = REPO_ROOT.parent
for _candidate in (REPO_ROOT, REPO_ROOT / "src"):
    _candidate_str = str(_candidate)
    if _candidate.exists() and _candidate_str not in sys.path:
        sys.path.insert(0, _candidate_str)

from enhanced_replica.cli_args import add_base_args, resolve_arg, setup_logging
from enhanced_replica.data_utils import get_ds_meta, load_dataset_manifest, load_dataset_splits, SPLITS
from enhanced_replica.io_utils import (
    create_run_context,
    prepare_primary_ds_for_train,
    save_pred_files,
    write_csv,
    write_run_manifest,
    write_run_report,
    write_yaml_minimal,
)
from enhanced_replica.rules import _DEFAULT_ENGINE, PatternRule, RuleEngine

# Column schema of the hit-record frame. Declared once so that an empty input
# still produces a frame with the columns downstream code indexes into.
_HIT_COLUMNS = (
    "record_id",
    "split",
    "label",
    "length_char",
    "rule_score",
    "rid",
    "rule_name",
    "category",
    "match_count",
    "weight",
    "matches",
)


def evaluate_rules_on_split(df: pd.DataFrame, engine: RuleEngine) -> pd.DataFrame:
    """Apply rule engine to a DataFrame and return expanded hit records.

    Each input row yields one output row per rule reported by
    ``engine.explain``; a row with no reported rule yields a single
    placeholder row with an empty ``rid`` and ``match_count`` of 0.

    Args:
        df: Samples with at least ``record_id``, ``text`` and ``label``
            columns. ``split`` and ``length_char`` are used when present;
            otherwise they default to ``""`` and ``len(text)``.
        engine: Rule engine; the first element of ``engine.score(text)`` is
            taken as the sample score and ``engine.explain(text)`` supplies
            the per-rule hit details.

    Returns:
        A DataFrame with the columns listed in ``_HIT_COLUMNS``. The columns
        are present even when ``df`` has no rows.
    """
    records = []
    for _, row in df.iterrows():
        text = str(row["text"])
        # Fields shared by every record emitted for this sample.
        base = {
            "record_id": row["record_id"],
            "split": row.get("split", ""),
            "label": int(row["label"]),
            "length_char": row.get("length_char", len(text)),
            "rule_score": engine.score(text)[0],
        }
        explain = engine.explain(text)
        if not explain:
            # Keep one summary row so samples without hits stay countable.
            records.append({
                **base,
                "rid": "",
                "rule_name": "",
                "category": "",
                "match_count": 0,
                "weight": 0.0,
                "matches": "",
            })
            continue
        for item in explain:
            records.append({
                **base,
                "rid": item["rid"],
                "rule_name": item["name"],
                "category": item["category"],
                "match_count": item["count"],
                "weight": item["weight"],
                "matches": " | ".join(str(m) for m in item["matches"]),
            })
    return pd.DataFrame(records, columns=list(_HIT_COLUMNS))


def compute_rule_precision(hit_df: pd.DataFrame, rules: list[PatternRule]) -> list[dict]:
    """Compute per-rule precision, recall, and hit rates.

    Computes both match-level precision (legacy) and sample-level precision
    to guard against extreme long-text outliers skewing the metric.

    Args:
        hit_df: Hit records with ``rid``, ``label``, ``record_id`` and
            ``match_count`` columns, as produced by
            ``evaluate_rules_on_split``. Label 1 is counted as AI and label 0
            as human.
        rules: Rules to report on; rules whose ``enabled`` is falsy are
            skipped.

    Returns:
        One dict per enabled rule, in the order of ``rules``. Ratios are
        rounded to 4 decimals and are 0.0 when their denominator is 0.
    """
    # Class masks and class totals do not depend on the rule: compute once.
    is_ai = hit_df["label"] == 1
    is_human = hit_df["label"] == 0
    total_ai = int(hit_df.loc[is_ai, "record_id"].nunique())
    total_human = int(hit_df.loc[is_human, "record_id"].nunique())

    rows = []
    for r in rules:
        if not r.enabled:
            continue
        fired = hit_df["rid"] == r.rid
        ai_hits = hit_df[fired & is_ai]
        human_hits = hit_df[fired & is_human]

        total_hits = int(hit_df.loc[fired, "match_count"].sum())
        tp = int(ai_hits["match_count"].sum())
        fp = int(human_hits["match_count"].sum())
        # Precision at match level (legacy)
        precision_match = tp / (tp + fp) if (tp + fp) > 0 else 0.0

        # Unique samples that fired this rule
        n_ai_fired = int(ai_hits["record_id"].nunique())
        n_human_fired = int(human_hits["record_id"].nunique())
        n_fired = n_ai_fired + n_human_fired
        # Sample-level precision: each sample counts once regardless of hit count
        precision_sample = n_ai_fired / n_fired if n_fired > 0 else 0.0

        recall_ai = n_ai_fired / total_ai if total_ai > 0 else 0.0
        fpr_human = n_human_fired / total_human if total_human > 0 else 0.0

        rows.append({
            "rid": r.rid,
            "rule_name": r.name,
            "category": r.category,
            "weight": r.weight,
            "total_match_hits": total_hits,
            "tp_matches": tp,
            "fp_matches": fp,
            "precision_match": round(precision_match, 4),
            "precision_sample": round(precision_sample, 4),
            "recall_ai": round(recall_ai, 4),
            "fpr_human": round(fpr_human, 4),
            "ai_samples_fired": n_ai_fired,
            "human_samples_fired": n_human_fired,
            "total_ai_samples": total_ai,
            "total_human_samples": total_human,
        })
    return rows


def run_e03(args: argparse.Namespace) -> dict:
    """Run the E03 rule-module evaluation and write its artifacts.

    Loads the DS01 splits, scores every non-empty split with the default rule
    engine, and writes prediction files, ``rule_hits.csv``,
    ``rule_precision.csv``, the run config, the run manifest and the run
    report into the run directory.

    Args:
        args: Parsed CLI arguments; reads ``log_level``, ``output_root``,
            ``run_name``, ``smoke``, ``manifest_file`` and ``seed``.

    Returns:
        Dict with ``dev_coverage``, ``test_coverage`` and
        ``determinative_candidates``.

    Raises:
        ValueError: If every split is empty.
    """
    setup_logging(args.log_level)
    logger = logging.getLogger("E03")

    ctx = create_run_context(eid="E03", output_root=Path(args.output_root), run_name=args.run_name)
    logger.info(f"E03 Rule module eval start | run_name={ctx.run_name} | smoke={args.smoke}")

    manifest = load_dataset_manifest(Path(args.manifest_file))
    ds_meta = get_ds_meta(manifest, "DS01")
    logger.info(f"Dataset: {ds_meta['dataset_id']} | dir={ds_meta['dataset_dir']}")

    splits = prepare_primary_ds_for_train(ds_meta, smoke=args.smoke, seed=args.seed)
    for sp in SPLITS:
        logger.info(f"  {sp}: {len(splits[sp])} rows")

    engine = _DEFAULT_ENGINE
    n_active_rules = sum(1 for r in engine.rules if r.enabled)
    logger.info(f"Active rules: {n_active_rules}")

    all_hits = []
    pred_splits = {}
    split_metrics = {}

    for sp in SPLITS:
        if splits[sp].empty:
            logger.warning(f"{sp} is empty, skipping.")
            continue

        hit_df = evaluate_rules_on_split(splits[sp], engine)
        hit_df["split"] = sp
        all_hits.append(hit_df)

        # Build prediction split for downstream consumers (E04/E05/E07)
        pred_df = splits[sp][["record_id", "text", "label", "source", "split", "length_char"]].copy()
        # Coerce with str() exactly as evaluate_rules_on_split does, so both
        # passes score the same string for a given cell.
        pred_df["score"] = pred_df["text"].apply(lambda t: engine.score(str(t))[0])
        pred_df["pred"] = (pred_df["score"] >= 0.5).astype(int)
        pred_splits[sp] = pred_df

        # Aggregate per-split quick stats
        fired = hit_df[hit_df["rid"] != ""]
        n_fired = fired["record_id"].nunique()
        n_total = len(splits[sp])
        split_metrics[sp] = {
            "samples": n_total,
            "fired": n_fired,
            "coverage": round(n_fired / n_total, 4) if n_total else 0.0,
        }
        logger.info(f"{sp} coverage: {n_fired}/{n_total} ({split_metrics[sp]['coverage']:.2%})")

    if not all_hits:
        raise ValueError("No dev/test data available for rule evaluation.")

    combined_hits = pd.concat(all_hits, ignore_index=True)
    rule_precisions = compute_rule_precision(combined_hits, engine.rules)

    # Save outputs
    save_pred_files(ctx, pred_splits)
    combined_hits.to_csv(ctx.run_dir / "rule_hits.csv", index=False, encoding="utf-8-sig")
    write_csv(ctx.run_dir / "rule_precision.csv", rule_precisions)

    # Determine determinative candidates (high sample-level precision + reasonable coverage)
    determinative_candidates = [
        r for r in rule_precisions
        if r["precision_sample"] >= 0.85 and r["recall_ai"] >= 0.05
    ]
    logger.info(f"Determinative candidates (precision_sample>=0.85 & recall_ai>=0.05): {len(determinative_candidates)}")
    for c in determinative_candidates:
        logger.info(f"  - {c['rid']}: precision_sample={c['precision_sample']}, recall_ai={c['recall_ai']}")

    config = {
        "seed": args.seed,
        "smoke": args.smoke,
        "active_rules": n_active_rules,
        "determinative_candidates": len(determinative_candidates),
    }
    write_yaml_minimal(ctx.config_file, config)

    result = {
        "dev_coverage": split_metrics.get("dev", {}).get("coverage", 0.0),
        "test_coverage": split_metrics.get("test", {}).get("coverage", 0.0),
        "determinative_candidates": len(determinative_candidates),
    }
    write_run_manifest(ctx, status="success", payload=result)
    write_run_report(ctx, status="success", config=config, payload=result)

    logger.info("E03 complete.")
    return result


def main() -> int:
    """Parse CLI arguments and run E03.

    Returns:
        0 on success. Any exception from ``run_e03`` is logged with its
        traceback and re-raised.
    """
    parser = argparse.ArgumentParser(description="E03 Rule module evaluation")
    parser = add_base_args(parser)
    args = parser.parse_args()
    try:
        run_e03(args)
    except Exception as e:
        # Top-level boundary: record the failure in the log, then propagate.
        logging.getLogger("E03").exception("ERROR: %s", e)
        raise
    return 0


if __name__ == "__main__":
    raise SystemExit(main())