| from __future__ import annotations |
|
|
| import json |
| import re |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any, Dict, List, Sequence |
|
|
| import numpy as np |
|
|
| from ..config import load_config |
| from ..data.dataset import load_and_prepare_dataset |
| from ..logging_utils import get_logger |
| from ..models.baseline_detoxify import DetoxifyPredictor |
| from ..models.hf_model import HFPredictor |
| from ..training.metrics import evaluate_multilabel |
| from ..utils import resolve_paths, sha256_text |
|
|
# Module-level logger obtained from the project's ``get_logger`` helper and
# named after this module; ``save_error_analysis`` uses it to report where
# the JSON report was written.
logger = get_logger(__name__)
|
|
|
|
# Surface-level text detectors consumed by ``_extract_features``.

# "http://" / "https://" links or bare "www." hosts, case-insensitive.
_RE_URL = re.compile(r"https?://\S+|www\.\S+", re.IGNORECASE)

# Anything shaped like local-part@domain.tld, case-insensitive.
_RE_EMAIL = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}", re.IGNORECASE)

# The same character (any except a newline) three or more times in a row.
_RE_REPEAT = re.compile(r"(.)\1{2,}")

# Characters counted as punctuation when computing ``punct_ratio``.
_PUNCT = set("!?.:,;\"'()[]{}<>/\\|@#$%^&*_+-=~`")
|
|
|
|
@dataclass
class ErrorAnalysisReport:
    """Bundle of everything one error-analysis run produces.

    Instances are built by ``run_error_analysis`` and immediately flattened
    with :meth:`to_dict`.
    """

    # Names of the labels, in the column order used for predictions.
    label_fields: List[str]
    # Number of samples that were analysed.
    n_samples: int
    # Decision threshold applied to probabilities to obtain hard predictions.
    threshold: float
    # Description of the model that produced the predictions.
    model: Dict[str, Any]
    # Output of the overall multilabel evaluation.
    overall_metrics: Dict[str, Any]
    # Per-label tp/fp/fn/tn counts, keyed by label name.
    confusion_per_label: Dict[str, Any]
    # Aggregated text-feature statistics for several sample subsets.
    feature_summary: Dict[str, Any]
    # Records describing the highest-error samples (no raw text).
    top_error_cases: List[Dict[str, Any]]

    def to_dict(self) -> Dict[str, Any]:
        """Return the report as a plain dict keyed by field name.

        The mapping is shallow: values are the very objects held by the
        instance, not copies, and keys appear in field-declaration order.
        """
        exported = (
            "label_fields",
            "n_samples",
            "threshold",
            "model",
            "overall_metrics",
            "confusion_per_label",
            "feature_summary",
            "top_error_cases",
        )
        return {key: getattr(self, key) for key in exported}
|
|
|
|
def _cuda_available() -> bool:
    """Report whether PyTorch can see a CUDA device.

    This is a best-effort probe: if importing ``torch`` or querying CUDA
    fails for any reason, the answer is simply ``False``.
    """
    try:
        import torch

        usable = bool(torch.cuda.is_available())
    except Exception:
        # Deliberately broad: any failure here just means "no GPU for us".
        return False
    return usable
|
|
|
|
def _extract_features(texts: Sequence[str]) -> Dict[str, np.ndarray]:
    """Compute simple surface features for every text.

    Falsy entries (``None``, ``""``) are treated as the empty string and
    yield all-zero features.

    Returns:
        A dict of 1-D arrays, one entry per input text, keyed by:
        ``length_chars``, ``length_words`` (whitespace-split token count),
        ``uppercase_ratio`` (uppercase letters / all letters),
        ``punct_ratio`` (characters in ``_PUNCT`` / all characters),
        ``exclam`` and ``question`` (counts of "!" and "?"),
        ``has_url``, ``has_email``, ``has_repeat`` (0/1 regex hits),
        ``non_ascii_ratio`` (code points > 127 / all characters) and
        ``has_quote`` (0/1, a single or double quote is present).
    """
    # Output column -> dtype, listed in the order the columns are returned.
    schema = {
        "length_chars": np.int32,
        "length_words": np.int32,
        "uppercase_ratio": np.float32,
        "punct_ratio": np.float32,
        "exclam": np.int32,
        "question": np.int32,
        "has_url": np.int8,
        "has_email": np.int8,
        "has_repeat": np.int8,
        "non_ascii_ratio": np.float32,
        "has_quote": np.int8,
    }
    total = len(texts)
    columns = {name: np.zeros((total,), dtype=dtype) for name, dtype in schema.items()}

    for row, raw in enumerate(texts):
        text = raw or ""
        n_chars = len(text)
        columns["length_chars"][row] = n_chars
        columns["length_words"][row] = len(text.split())

        if n_chars:
            letters = [ch for ch in text if ch.isalpha()]
            # Ratio stays at its initial 0.0 when the text has no letters.
            if letters:
                n_upper = sum(ch.isupper() for ch in letters)
                columns["uppercase_ratio"][row] = n_upper / len(letters)

            n_punct = sum(ch in _PUNCT for ch in text)
            columns["punct_ratio"][row] = n_punct / n_chars

            n_non_ascii = sum(ord(ch) > 127 for ch in text)
            columns["non_ascii_ratio"][row] = n_non_ascii / n_chars

        columns["exclam"][row] = text.count("!")
        columns["question"][row] = text.count("?")
        columns["has_url"][row] = int(_RE_URL.search(text) is not None)
        columns["has_email"][row] = int(_RE_EMAIL.search(text) is not None)
        columns["has_repeat"][row] = int(_RE_REPEAT.search(text) is not None)
        columns["has_quote"][row] = int('"' in text or "'" in text)

    return columns
|
|
|
|
def _summarize_features(features: Dict[str, np.ndarray], mask: np.ndarray) -> Dict[str, Any]:
    """Aggregate each feature over the rows selected by ``mask``.

    Args:
        features: Feature name -> per-sample array.
        mask: Per-sample selector; rows where it is truthy are summarised.

    Returns:
        ``{"n": 0}`` when nothing is selected. Otherwise ``{"n": <count>}``
        plus, for every feature, a ``{"mean", "median", "p90"}`` dict of
        plain Python floats.
    """
    selected = np.asarray(mask).nonzero()[0]
    if selected.size == 0:
        return {"n": 0}

    def describe(values: np.ndarray) -> Dict[str, float]:
        # Plain floats keep the summary JSON-serialisable.
        return {
            "mean": float(np.mean(values)),
            "median": float(np.median(values)),
            "p90": float(np.percentile(values, 90)),
        }

    summary: Dict[str, Any] = {"n": int(selected.size)}
    summary.update((name, describe(column[selected])) for name, column in features.items())
    return summary
|
|
|
|
def _confusion_counts(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, int]:
    """Count confusion-matrix cells for one binary label.

    Only entries equal to exactly 0 or 1 participate; any other value in
    either array falls into none of the four cells.

    Returns:
        ``{"tp", "fp", "fn", "tn"}`` mapped to plain ``int`` counts.
    """
    actual_pos = y_true == 1
    actual_neg = y_true == 0
    guessed_pos = y_pred == 1
    guessed_neg = y_pred == 0

    cells = {
        "tp": actual_pos & guessed_pos,
        "fp": actual_neg & guessed_pos,
        "fn": actual_pos & guessed_neg,
        "tn": actual_neg & guessed_neg,
    }
    return {name: int(hits.sum()) for name, hits in cells.items()}
|
|
|
|
def _predict_probs(
    *,
    model_kind: str,
    texts: Sequence[str],
    label_fields: List[str],
    cfg: Any,
    paths: Any,
    device: str,
) -> tuple[np.ndarray, Dict[str, Any]]:
    """Score ``texts`` with the model selected by ``model_kind``.

    Returns:
        The probability matrix produced by the predictor, and a small dict
        describing which model was used.

    Raises:
        FileNotFoundError: ``model_kind`` is "finetuned" but the model
            directory does not exist.
        ValueError: ``model_kind`` is not a supported value.
    """
    if model_kind == "finetuned":
        model_dir = paths.models_dir / "finetuned" / "latest"
        if not model_dir.exists():
            raise FileNotFoundError(f"Fine-tuned model not found: {model_dir}. Run training first.")
        predictor = HFPredictor(model_dir=model_dir, device=device, max_length=int(cfg["model"]["max_length"]))
        probs = predictor.predict_proba_matrix(texts, label_order=label_fields, batch_size=64)
        return probs, {"kind": model_kind, "model_dir": str(model_dir)}

    if model_kind == "detoxify-unbiased":
        detox = DetoxifyPredictor(model_type="unbiased", device=device)
        # Translate this project's label name to the one the baseline uses.
        label_map = {"identity_hate": "identity_attack"}
        probs = detox.predict_proba_matrix(texts, label_order=label_fields, label_map=label_map, batch_size=64)
        return probs, {"kind": model_kind, "detoxify_model_type": "unbiased"}

    raise ValueError(f"Unknown model_kind: {model_kind}")


def _collect_top_cases(
    *,
    texts: Sequence[str],
    feats: Dict[str, np.ndarray],
    probs: np.ndarray,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    label_fields: List[str],
    top_k: int,
) -> List[Dict[str, Any]]:
    """Describe the ``top_k`` samples with the largest per-label probability error.

    Each record carries a hash of the text (never the text itself), a few
    numeric features, and the true/predicted labels and probabilities.
    """
    # A sample's score is its worst |probability - truth| over all labels.
    row_score = np.abs(probs - y_true).max(axis=1)

    # Clamp so a negative top_k yields no cases instead of slicing from the
    # end (which would return almost every sample).
    limit = max(0, int(top_k))
    # Stable sort: samples with equal scores keep dataset order, so the
    # selection is reproducible from run to run.
    ranked = np.argsort(-row_score, kind="stable")[:limit]

    cases: List[Dict[str, Any]] = []
    for i in ranked.tolist():
        cases.append(
            {
                "sha256": sha256_text(texts[i]),
                "length_chars": int(feats["length_chars"][i]),
                "length_words": int(feats["length_words"][i]),
                "has_url": int(feats["has_url"][i]),
                "has_email": int(feats["has_email"][i]),
                "has_quote": int(feats["has_quote"][i]),
                "uppercase_ratio": float(feats["uppercase_ratio"][i]),
                "max_abs_error": float(row_score[i]),
                "true_labels": {lf: int(y_true[i, j] > 0.5) for j, lf in enumerate(label_fields)},
                "pred_labels": {lf: int(y_pred[i, j]) for j, lf in enumerate(label_fields)},
                "pred_probs": {lf: float(probs[i, j]) for j, lf in enumerate(label_fields)},
            }
        )
    return cases


def run_error_analysis(
    *,
    train_config_path: str,
    split: str = "test",
    threshold: float = 0.5,
    max_samples: int | None = None,
    model_kind: str = "finetuned",
    top_k_cases: int = 50,
) -> Dict[str, Any]:
    """Run privacy-preserving error analysis.

    - No raw text is printed or written.
    - We only store hashed ids + numeric features for a few top error cases.

    Args:
        train_config_path: Path of the training config passed to
            ``load_config``.
        split: Name of the dataset split to analyse.
        threshold: Probabilities ``>= threshold`` count as positive
            predictions.
        max_samples: If given, only the first ``max_samples`` rows of the
            split are analysed.
        model_kind:
            - "finetuned" (default) -> ``<models_dir>/finetuned/latest``
            - "detoxify-unbiased" -> detoxify baseline
        top_k_cases: Number of highest-error samples to describe; negative
            values are treated as 0.

    Returns:
        The ``ErrorAnalysisReport`` for the run, flattened via ``to_dict()``.

    Raises:
        ValueError: The split does not exist, contains no samples after
            applying ``max_samples``, or ``model_kind`` is unknown.
        FileNotFoundError: The fine-tuned model directory is missing.
    """
    cfg = load_config(train_config_path)
    paths_cfg = cfg.get("paths", {})
    paths = resolve_paths(
        data_dir_cfg=str(paths_cfg.get("data_dir", "")),
        artifacts_dir_cfg=str(paths_cfg.get("artifacts_dir", "")),
    )

    loaded = load_and_prepare_dataset(cfg)
    label_fields = loaded.label_fields

    ds = loaded.dataset.get(split)
    if ds is None:
        raise ValueError(f"Split '{split}' not found. Available: {list(loaded.dataset.keys())}")
    if max_samples is not None:
        ds = ds.select(range(max(0, min(int(max_samples), len(ds)))))
    # Fail before loading a model: with zero rows the label matrix is 1-D
    # and the per-label indexing below would die with an opaque IndexError.
    if len(ds) == 0:
        raise ValueError(f"Split '{split}' has no samples to analyze (max_samples={max_samples}).")

    texts = ds["text"]
    y_true = np.array(ds["labels"], dtype=np.float32)

    device = "cuda" if _cuda_available() else "cpu"
    probs, model_info = _predict_probs(
        model_kind=model_kind,
        texts=texts,
        label_fields=label_fields,
        cfg=cfg,
        paths=paths,
        device=device,
    )

    cutoff = float(threshold)
    y_pred = (probs >= cutoff).astype(np.int32)
    overall = evaluate_multilabel(y_true, probs, label_fields, threshold=cutoff).to_dict()

    # Integer view of the ground truth, computed once for all comparisons.
    # NOTE(review): this truncates toward zero, whereas "true_labels" in the
    # top cases uses `> 0.5`. The two agree for hard 0/1 labels - confirm the
    # dataset never carries soft or sentinel label values.
    y_true_int = y_true.astype(np.int32)

    confusion = {
        lf: _confusion_counts(y_true_int[:, j], y_pred[:, j])
        for j, lf in enumerate(label_fields)
    }

    feats = _extract_features(texts)

    # A sample is an error if any of its labels is mispredicted.
    any_error = np.any(y_pred != y_true_int, axis=1)

    # False-positive / false-negative masks for all labels at once.
    fp_mask = (y_true_int == 0) & (y_pred == 1)
    fn_mask = (y_true_int == 1) & (y_pred == 0)
    per_label_summaries: Dict[str, Any] = {
        lf: {
            "fp_features": _summarize_features(feats, fp_mask[:, j]),
            "fn_features": _summarize_features(feats, fn_mask[:, j]),
        }
        for j, lf in enumerate(label_fields)
    }

    feature_summary = {
        "all_samples": _summarize_features(feats, np.ones((len(texts),), dtype=bool)),
        "any_error": _summarize_features(feats, any_error),
        "per_label": per_label_summaries,
        "notes": [
            "Feature summaries are numeric aggregates (mean/median/p90).",
            "No raw text is stored. 'top_error_cases' uses sha256(text) only.",
        ],
    }

    top_cases = _collect_top_cases(
        texts=texts,
        feats=feats,
        probs=probs,
        y_true=y_true,
        y_pred=y_pred,
        label_fields=label_fields,
        top_k=top_k_cases,
    )

    report = ErrorAnalysisReport(
        label_fields=label_fields,
        n_samples=int(len(texts)),
        threshold=cutoff,
        model=model_info,
        overall_metrics=overall,
        confusion_per_label=confusion,
        feature_summary=feature_summary,
        top_error_cases=top_cases,
    )
    return report.to_dict()
|
|
|
|
def save_error_analysis(out_path: Path, report: Dict[str, Any]) -> Path:
    """Write ``report`` to ``out_path`` as indented UTF-8 JSON.

    Missing parent directories are created first. Non-ASCII characters are
    written as-is rather than escaped.

    Returns:
        ``out_path``, unchanged, for convenient chaining.
    """
    destination_dir = out_path.parent
    destination_dir.mkdir(parents=True, exist_ok=True)

    serialized = json.dumps(report, indent=2, ensure_ascii=False)
    out_path.write_text(serialized, encoding="utf-8")

    logger.info(f"Saved error analysis report to {out_path}")
    return out_path
|
|