| |
| """ |
| Fusion + threshold tuning for Subnet32-style miners. |
| |
| Default (--fusion_mode weighted): |
| 1) Impute missing branch scores with medians from --fit. |
| 2) Normalize ada / l2d / sup on the fit split (min-max or z-score→sigmoid). |
| 3) final_score = w_ada*ada_n + w_l2d*l2d_n + w_sup*sup_n (default 0.45 / 0.30 / 0.25). |
| 4) Sweep threshold on --tune; maximize weighted reward (default emphasizes FP_score). |
| |
| Optional --fusion_mode logistic: StandardScaler + logistic regression on branch scores |
| + num_words + num_sentences (previous behavior). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| from typing import Any, Dict, List, Tuple |
|
|
| import numpy as np |
| from sklearn.linear_model import LogisticRegression |
| from sklearn.preprocessing import StandardScaler |
|
|
| _ROOT = os.path.join(os.path.dirname(__file__), "..") |
| if _ROOT not in sys.path: |
| sys.path.insert(0, _ROOT) |
|
|
| from miner_lab.subnet_metrics import metrics_at_threshold, stress_report, sweep_thresholds |
| from miner_lab.text_features import count_sentences, count_words |
| from miner_lab.weighted_fusion import ( |
| DEFAULT_WEIGHTS, |
| compute_medians, |
| fit_minmax, |
| fit_zscore, |
| fusion_prob_array, |
| renormalize_weights, |
| ) |
|
|
# Branch score columns expected on each merged JSONL row (before --omit_* filtering).
SCORE_KEYS = ("ada_score", "l2d_score", "sup_score")
|
|
|
|
def load_jsonl(path: str) -> List[dict]:
    """Parse a JSONL file into a list of dicts, skipping blank lines."""
    parsed: List[dict] = []
    with open(path, encoding="utf-8") as fh:
        for raw in fh:
            stripped = raw.strip()
            if stripped:
                parsed.append(json.loads(stripped))
    return parsed
|
|
|
|
| def _finite_float(x) -> bool: |
| try: |
| v = float(x) |
| return np.isfinite(v) |
| except (TypeError, ValueError): |
| return False |
|
|
|
|
def compute_medians_legacy(rows: List[dict], keys: Tuple[str, ...]) -> Dict[str, float]:
    """Per-key median over rows carrying a finite value for that key.

    Keys with no finite observation map to 0.0 (used as the imputation value).
    """
    def _finite(value) -> bool:
        # Finite-float check, inlined so this helper stands alone.
        try:
            return bool(np.isfinite(float(value)))
        except (TypeError, ValueError):
            return False

    medians: Dict[str, float] = {}
    for key in keys:
        observed = [float(r[key]) for r in rows if key in r and _finite(r[key])]
        medians[key] = float(np.median(observed)) if observed else 0.0
    return medians
|
|
|
|
def build_matrix(
    rows: List[dict],
    feature_names: List[str],
    medians: Dict[str, float],
) -> Tuple[np.ndarray, np.ndarray]:
    """Assemble the (X, y) design matrix for logistic-mode fusion.

    Score features missing or non-finite on a row are imputed with *medians*
    (0.0 when the key is absent there too); ``num_words`` / ``num_sentences``
    are derived from the row's ``text`` field.
    """
    def _is_finite(value) -> bool:
        # Finite-float check, inlined so the block is self-contained.
        try:
            return bool(np.isfinite(float(value)))
        except (TypeError, ValueError):
            return False

    labels = [int(row["label"]) for row in rows]
    features: List[List[float]] = []
    for row in rows:
        text = row.get("text", "")
        vec: List[float] = []
        for name in feature_names:
            if name == "num_words":
                vec.append(float(count_words(text)))
            elif name == "num_sentences":
                vec.append(float(count_sentences(text)))
            else:
                raw = row.get(name)
                if raw is None or not _is_finite(raw):
                    raw = medians.get(name, 0.0)
                vec.append(float(raw))
        features.append(vec)
    return np.asarray(features, dtype=np.float64), np.asarray(labels, dtype=np.int64)
|
|
|
|
def attach_fusion_scores(rows: List[dict], probs: np.ndarray, key: str = "fusion_score") -> None:
    """Write each probability back onto its row under *key* (mutates rows in place)."""
    for row, prob in zip(rows, map(float, probs)):
        row[key] = prob
|
|
|
|
# Shown verbatim in the argparse epilog and appended to the --stress
# missing-file error (see _need in main); do not reword casually.
_STRESS_HELP = """
Stress merged JSONL must exist before using --stress. Build it like val.merged:
- Start from exp_main/data/subnet32_stress.samples.jsonl (see holdout_pairs_jsonl.py).
- Run run_ada_jsonl.py on that JSONL, then merge_score_jsonl (--base stress.samples --ada ...).
- Optionally merge sup_score the same way as val.
Or omit --stress to only fit/tune and write artifact + report (no stress block).
"""
|
|
|
|
def main():
    """CLI entry point.

    Fits the fusion model on --fit, sweeps a decision threshold on --tune
    (maximizing the weighted reward), optionally evaluates --stress / --test
    at that threshold, then writes a deployment artifact and a JSON report.
    """
    p = argparse.ArgumentParser(
        description="Weighted normalized fusion (default) or logistic fusion + threshold tuning.",
        epilog=_STRESS_HELP,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--fit", required=True, help="Merged JSONL to fit normalization / LR (e.g. train)")
    p.add_argument("--tune", required=True, help="Merged JSONL for threshold sweep (e.g. val)")
    p.add_argument("--stress", help="Merged JSONL for stress report")
    p.add_argument("--test", help="Optional merged JSONL for locked test metrics")
    p.add_argument(
        "--fusion_mode",
        choices=("weighted", "logistic"),
        default="weighted",
        help="weighted: normalize branch scores on --fit then fixed weights (default Subnet32 recipe). "
        "logistic: StandardScaler + LR on scores + length features.",
    )
    p.add_argument(
        "--norm",
        choices=("minmax", "zscore_sigmoid"),
        default="minmax",
        help="Branch normalization (weighted mode only).",
    )
    # Per-branch fusion weights; defaults come from the shared DEFAULT_WEIGHTS recipe.
    p.add_argument("--w_ada", type=float, default=DEFAULT_WEIGHTS["ada_score"])
    p.add_argument("--w_l2d", type=float, default=DEFAULT_WEIGHTS["l2d_score"])
    p.add_argument("--w_sup", type=float, default=DEFAULT_WEIGHTS["sup_score"])
    p.add_argument("--omit_l2d", action="store_true", help="Drop l2d_score branch (weights renormalized)")
    p.add_argument("--omit_sup", action="store_true", help="Drop sup_score branch (weights renormalized)")
    # Reward weights: (F1, FP_score, AP) passed to sweep_thresholds / metrics_at_threshold.
    p.add_argument(
        "--reward_w_f1",
        type=float,
        default=0.25,
        help="Weight for F1 in reward (default 0.25; increase FP emphasis via --reward_w_fp).",
    )
    p.add_argument("--reward_w_fp", type=float, default=0.5, help="Weight for FP_score in reward (default 0.5).")
    p.add_argument("--reward_w_ap", type=float, default=0.25, help="Weight for AP in reward (default 0.25).")
    p.add_argument(
        "--artifact",
        default="./exp_main/results/fusion_artifact.json",
        help="Write fusion metadata for deployment",
    )
    p.add_argument(
        "--report",
        default="./exp_main/results/fusion_stress_report.json",
        help="Write JSON report",
    )
    args = p.parse_args()

    reward_weights = (args.reward_w_f1, args.reward_w_fp, args.reward_w_ap)

    def _need(path: str, flag: str) -> None:
        # Fail fast with a readable message; --stress additionally gets build help.
        if not os.path.isfile(path):
            extra = ""
            if flag == "--stress":
                extra = "\n" + _STRESS_HELP.strip()
            raise SystemExit(f"{flag} file not found: {path!r}{extra}")

    _need(args.fit, "--fit")
    _need(args.tune, "--tune")
    if args.stress:
        _need(args.stress, "--stress")
    if args.test:
        _need(args.test, "--test")

    fit_rows = load_jsonl(args.fit)
    tune_rows = load_jsonl(args.tune)

    # Active branch columns after optional --omit_* filtering.
    score_keys = list(SCORE_KEYS)
    if args.omit_l2d:
        score_keys = [k for k in score_keys if k != "l2d_score"]
    if args.omit_sup:
        score_keys = [k for k in score_keys if k != "sup_score"]

    # Restrict weights to the active branches and rescale (renormalize_weights).
    branch_weights = renormalize_weights(
        {"ada_score": args.w_ada, "l2d_score": args.w_l2d, "sup_score": args.w_sup},
        score_keys,
    )

    if args.fusion_mode == "weighted":
        # --- Weighted mode: impute with fit-split medians, normalize, fixed weights ---
        medians = compute_medians(fit_rows, score_keys)
        if args.norm == "minmax":
            norm_params = fit_minmax(fit_rows, score_keys, medians)
        else:
            norm_params = fit_zscore(fit_rows, score_keys, medians)

        tune_probs = fusion_prob_array(
            tune_rows, score_keys, medians, norm_params, branch_weights, args.norm
        )

        # Pick the threshold that maximizes the weighted reward on the tune split.
        best_t, best_metrics, sweep_rows = sweep_thresholds(
            np.array([int(r["label"]) for r in tune_rows], dtype=np.int64),
            tune_probs,
            reward_weights=reward_weights,
        )
        print("Fusion mode: weighted (normalize on --fit, then branch weights)")
        print("Branch weights:", branch_weights)
        print("Norm:", args.norm)
        print("Best threshold on tune split:", best_t)
        print(json.dumps(best_metrics, indent=2))

        attach_fusion_scores(tune_rows, tune_probs)

        report: Dict[str, Any] = {
            "fusion_mode": "weighted",
            "norm_method": args.norm,
            "score_keys": score_keys,
            "branch_weights": branch_weights,
            "medians": medians,
            "norm_params": norm_params,
            "reward_weights": {"f1": args.reward_w_f1, "fp_score": args.reward_w_fp, "ap": args.reward_w_ap},
            "best_threshold": best_t,
            "tune_metrics_at_best": best_metrics,
            "sweep_head": sweep_rows[:5],
            "sweep_tail": sweep_rows[-5:],
        }

        # Everything a deployment needs to reproduce fusion scoring + thresholding.
        artifact: Dict[str, Any] = {
            "fusion_mode": "weighted",
            "norm_method": args.norm,
            "score_keys": score_keys,
            "branch_weights": branch_weights,
            "medians": medians,
            "norm_params": norm_params,
            "best_threshold": best_t,
            "reward_weights": [args.reward_w_f1, args.reward_w_fp, args.reward_w_ap],
        }

        if args.stress:
            # Per-slice stress report plus scalar metrics at the tuned threshold.
            stress_rows = load_jsonl(args.stress)
            ys = np.array([int(r["label"]) for r in stress_rows], dtype=np.int64)
            s_probs = fusion_prob_array(
                stress_rows, score_keys, medians, norm_params, branch_weights, args.norm
            )
            attach_fusion_scores(stress_rows, s_probs)
            report["stress"] = stress_report(
                stress_rows,
                score_key="fusion_score",
                threshold=best_t,
                slice_keys=("domain", "generator_family", "augmentation_type", "split"),
                reward_weights=reward_weights,
            )
            st_m = metrics_at_threshold(ys, s_probs, best_t, reward_weights=reward_weights)
            report["stress_scalar"] = st_m
            print("Stress metrics at best threshold:", json.dumps(st_m, indent=2))

        if args.test:
            # Locked test metrics at the threshold tuned on --tune (no re-tuning).
            test_rows = load_jsonl(args.test)
            ye = np.array([int(r["label"]) for r in test_rows], dtype=np.int64)
            e_probs = fusion_prob_array(
                test_rows, score_keys, medians, norm_params, branch_weights, args.norm
            )
            report["test"] = metrics_at_threshold(ye, e_probs, best_t, reward_weights=reward_weights)

    else:
        # --- Logistic mode: StandardScaler + LR on scores + length features ---
        feature_names = list(score_keys) + ["num_words", "num_sentences"]
        medians = compute_medians_legacy(fit_rows, tuple(score_keys))

        X_fit, y_fit = build_matrix(fit_rows, feature_names, medians)
        X_tune, y_tune = build_matrix(tune_rows, feature_names, medians)

        # Scaler is fit on --fit only; tune/stress/test reuse its statistics.
        scaler = StandardScaler()
        Xf = scaler.fit_transform(X_fit)
        Xt = scaler.transform(X_tune)

        clf = LogisticRegression(max_iter=4000, random_state=42)
        clf.fit(Xf, y_fit)
        tune_probs = clf.predict_proba(Xt)[:, 1]

        best_t, best_metrics, sweep_rows = sweep_thresholds(
            y_tune, tune_probs, reward_weights=reward_weights
        )
        print("Fusion mode: logistic (Scaler + LR on scores + num_words + num_sentences)")
        print("Best threshold on tune split:", best_t)
        print(json.dumps(best_metrics, indent=2))

        attach_fusion_scores(tune_rows, tune_probs)

        report = {
            "fusion_mode": "logistic",
            "feature_names": feature_names,
            "medians": medians,
            "reward_weights": {"f1": args.reward_w_f1, "fp_score": args.reward_w_fp, "ap": args.reward_w_ap},
            "best_threshold": best_t,
            "tune_metrics_at_best": best_metrics,
            "sweep_head": sweep_rows[:5],
            "sweep_tail": sweep_rows[-5:],
        }

        # Flattened scaler/LR parameters so deployment needs no sklearn pickle.
        artifact = {
            "fusion_mode": "logistic",
            "feature_names": feature_names,
            "medians": medians,
            "scaler_mean": scaler.mean_.tolist(),
            "scaler_scale": scaler.scale_.tolist(),
            "lr_coef": clf.coef_.ravel().tolist(),
            "lr_intercept": float(clf.intercept_[0]),
            "best_threshold": best_t,
            "reward_weights": [args.reward_w_f1, args.reward_w_fp, args.reward_w_ap],
        }

        if args.stress:
            stress_rows = load_jsonl(args.stress)
            Xs, ys = build_matrix(stress_rows, feature_names, medians)
            Xs = scaler.transform(Xs)
            s_probs = clf.predict_proba(Xs)[:, 1]
            attach_fusion_scores(stress_rows, s_probs)
            report["stress"] = stress_report(
                stress_rows,
                score_key="fusion_score",
                threshold=best_t,
                slice_keys=("domain", "generator_family", "augmentation_type", "split"),
                reward_weights=reward_weights,
            )
            st_m = metrics_at_threshold(ys, s_probs, best_t, reward_weights=reward_weights)
            report["stress_scalar"] = st_m
            print("Stress metrics at best threshold:", json.dumps(st_m, indent=2))

        if args.test:
            test_rows = load_jsonl(args.test)
            Xe, ye = build_matrix(test_rows, feature_names, medians)
            Xe = scaler.transform(Xe)
            e_probs = clf.predict_proba(Xe)[:, 1]
            report["test"] = metrics_at_threshold(ye, e_probs, best_t, reward_weights=reward_weights)

    # Persist outputs, creating parent directories as needed.
    os.makedirs(os.path.dirname(os.path.abspath(args.artifact)) or ".", exist_ok=True)
    with open(args.artifact, "w", encoding="utf-8") as f:
        json.dump(artifact, f, indent=2)
    print(f"Wrote artifact {args.artifact}")

    os.makedirs(os.path.dirname(os.path.abspath(args.report)) or ".", exist_ok=True)
    with open(args.report, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print(f"Wrote report {args.report}")
|
|
|
|
# Script entry point: run the fusion/tuning CLI when executed directly.
if __name__ == "__main__":
    main()
|
|