# Source: subnet32-llm-detector/scripts/run_fusion_and_tune.py
# (commit 485127c, "incomplete commit", author ThaoTran7 — web-export residue converted to a comment)
#!/usr/bin/env python3
"""
Fusion + threshold tuning for Subnet32-style miners.
Default (--fusion_mode weighted):
1) Impute missing branch scores with medians from --fit.
2) Normalize ada / l2d / sup on the fit split (min-max or z-score→sigmoid).
3) final_score = w_ada*ada_n + w_l2d*l2d_n + w_sup*sup_n (default 0.45 / 0.30 / 0.25).
4) Sweep threshold on --tune; maximize weighted reward (default emphasizes FP_score).
Optional --fusion_mode logistic: StandardScaler + logistic regression on branch scores
+ num_words + num_sentences (previous behavior).
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from typing import Any, Dict, List, Tuple
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
_ROOT = os.path.join(os.path.dirname(__file__), "..")
if _ROOT not in sys.path:
sys.path.insert(0, _ROOT)
from miner_lab.subnet_metrics import metrics_at_threshold, stress_report, sweep_thresholds # noqa: E402
from miner_lab.text_features import count_sentences, count_words # noqa: E402
from miner_lab.weighted_fusion import ( # noqa: E402
DEFAULT_WEIGHTS,
compute_medians,
fit_minmax,
fit_zscore,
fusion_prob_array,
renormalize_weights,
)
SCORE_KEYS = ("ada_score", "l2d_score", "sup_score")
def load_jsonl(path: str) -> List[dict]:
    """Read a JSONL file and return one parsed dict per non-blank line."""
    with open(path, encoding="utf-8") as handle:
        # Strip each raw line; blank lines yield a falsy string and are skipped.
        return [json.loads(stripped) for raw in handle if (stripped := raw.strip())]
def _finite_float(x) -> bool:
try:
v = float(x)
return np.isfinite(v)
except (TypeError, ValueError):
return False
def compute_medians_legacy(rows: List[dict], keys: Tuple[str, ...]) -> Dict[str, float]:
    """Per-key median over the finite values found in *rows*; 0.0 when none exist."""
    medians: Dict[str, float] = {}
    for key in keys:
        finite_vals: List[float] = []
        for row in rows:
            if key not in row:
                continue
            # Inline finite-float check: skip non-numeric / NaN / inf entries.
            try:
                value = float(row[key])
            except (TypeError, ValueError):
                continue
            if np.isfinite(value):
                finite_vals.append(value)
        medians[key] = float(np.median(finite_vals)) if finite_vals else 0.0
    return medians
def build_matrix(
    rows: List[dict],
    feature_names: List[str],
    medians: Dict[str, float],
) -> Tuple[np.ndarray, np.ndarray]:
    """Assemble the (n_rows, n_features) float matrix and int label vector.

    Length features are derived from each row's text on the fly; score features
    that are missing or non-finite fall back to the fitted medians (0.0 if the
    key is absent from *medians*).
    """

    def _value(row: dict, name: str) -> float:
        # Text-derived length features.
        if name == "num_words":
            return float(count_words(row.get("text", "")))
        if name == "num_sentences":
            return float(count_sentences(row.get("text", "")))
        # Branch score: impute with the fit-split median when unusable.
        raw = row.get(name)
        if raw is None or not _finite_float(raw):
            raw = medians.get(name, 0.0)
        return float(raw)

    labels = [int(row["label"]) for row in rows]
    matrix = [[_value(row, name) for name in feature_names] for row in rows]
    return np.asarray(matrix, dtype=np.float64), np.asarray(labels, dtype=np.int64)
def attach_fusion_scores(rows: List[dict], probs: np.ndarray, key: str = "fusion_score") -> None:
    """Write each probability back onto its row under *key* (mutates rows in place)."""
    # Truncate to the shorter sequence, matching zip() pairing semantics.
    for position, row in enumerate(rows[: len(probs)]):
        row[key] = float(probs[position])
# Shown as the argparse epilog and appended to the --stress missing-file error.
# Runtime string: edits here change user-visible --help output.
_STRESS_HELP = """
Stress merged JSONL must exist before using --stress. Build it like val.merged:
- Start from exp_main/data/subnet32_stress.samples.jsonl (see holdout_pairs_jsonl.py).
- Run run_ada_jsonl.py on that JSONL, then merge_score_jsonl (--base stress.samples --ada ...).
- Optionally merge sup_score the same way as val.
Or omit --stress to only fit/tune and write artifact + report (no stress block).
"""
def main():
    """CLI entry point: fit fusion on --fit, tune the threshold on --tune, persist results.

    Always writes two JSON files: the deployment artifact (parameters needed to
    reproduce the fusion at inference time) and a report (tune metrics, sweep
    excerpts, optional stress/test blocks).
    """
    p = argparse.ArgumentParser(
        description="Weighted normalized fusion (default) or logistic fusion + threshold tuning.",
        epilog=_STRESS_HELP,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--fit", required=True, help="Merged JSONL to fit normalization / LR (e.g. train)")
    p.add_argument("--tune", required=True, help="Merged JSONL for threshold sweep (e.g. val)")
    p.add_argument("--stress", help="Merged JSONL for stress report")
    p.add_argument("--test", help="Optional merged JSONL for locked test metrics")
    p.add_argument(
        "--fusion_mode",
        choices=("weighted", "logistic"),
        default="weighted",
        help="weighted: normalize branch scores on --fit then fixed weights (default Subnet32 recipe). "
        "logistic: StandardScaler + LR on scores + length features.",
    )
    p.add_argument(
        "--norm",
        choices=("minmax", "zscore_sigmoid"),
        default="minmax",
        help="Branch normalization (weighted mode only).",
    )
    # Per-branch weights; renormalized below over whichever branches survive --omit_*.
    p.add_argument("--w_ada", type=float, default=DEFAULT_WEIGHTS["ada_score"])
    p.add_argument("--w_l2d", type=float, default=DEFAULT_WEIGHTS["l2d_score"])
    p.add_argument("--w_sup", type=float, default=DEFAULT_WEIGHTS["sup_score"])
    p.add_argument("--omit_l2d", action="store_true", help="Drop l2d_score branch (weights renormalized)")
    p.add_argument("--omit_sup", action="store_true", help="Drop sup_score branch (weights renormalized)")
    p.add_argument(
        "--reward_w_f1",
        type=float,
        default=0.25,
        help="Weight for F1 in reward (default 0.25; increase FP emphasis via --reward_w_fp).",
    )
    p.add_argument("--reward_w_fp", type=float, default=0.5, help="Weight for FP_score in reward (default 0.5).")
    p.add_argument("--reward_w_ap", type=float, default=0.25, help="Weight for AP in reward (default 0.25).")
    p.add_argument(
        "--artifact",
        default="./exp_main/results/fusion_artifact.json",
        help="Write fusion metadata for deployment",
    )
    p.add_argument(
        "--report",
        default="./exp_main/results/fusion_stress_report.json",
        help="Write JSON report",
    )
    args = p.parse_args()
    # (f1, fp_score, ap) tuple consumed by sweep_thresholds / metrics_at_threshold.
    reward_weights = (args.reward_w_f1, args.reward_w_fp, args.reward_w_ap)

    def _need(path: str, flag: str) -> None:
        # Fail fast on a missing input file; --stress also prints build instructions.
        if not os.path.isfile(path):
            extra = ""
            if flag == "--stress":
                extra = "\n" + _STRESS_HELP.strip()
            raise SystemExit(f"{flag} file not found: {path!r}{extra}")

    _need(args.fit, "--fit")
    _need(args.tune, "--tune")
    if args.stress:
        _need(args.stress, "--stress")
    if args.test:
        _need(args.test, "--test")
    fit_rows = load_jsonl(args.fit)
    tune_rows = load_jsonl(args.tune)
    # Drop omitted branches, then renormalize remaining weights (renormalize_weights).
    score_keys = list(SCORE_KEYS)
    if args.omit_l2d:
        score_keys = [k for k in score_keys if k != "l2d_score"]
    if args.omit_sup:
        score_keys = [k for k in score_keys if k != "sup_score"]
    branch_weights = renormalize_weights(
        {"ada_score": args.w_ada, "l2d_score": args.w_l2d, "sup_score": args.w_sup},
        score_keys,
    )
    if args.fusion_mode == "weighted":
        # Medians and normalization params are fitted on --fit rows only;
        # tune/stress/test splits are scored with those fixed parameters.
        medians = compute_medians(fit_rows, score_keys)
        if args.norm == "minmax":
            norm_params = fit_minmax(fit_rows, score_keys, medians)
        else:
            norm_params = fit_zscore(fit_rows, score_keys, medians)
        tune_probs = fusion_prob_array(
            tune_rows, score_keys, medians, norm_params, branch_weights, args.norm
        )
        # Sweep thresholds on the tune split, maximizing the weighted reward.
        best_t, best_metrics, sweep_rows = sweep_thresholds(
            np.array([int(r["label"]) for r in tune_rows], dtype=np.int64),
            tune_probs,
            reward_weights=reward_weights,
        )
        print("Fusion mode: weighted (normalize on --fit, then branch weights)")
        print("Branch weights:", branch_weights)
        print("Norm:", args.norm)
        print("Best threshold on tune split:", best_t)
        print(json.dumps(best_metrics, indent=2))
        attach_fusion_scores(tune_rows, tune_probs)
        # report = human-readable summary; artifact = parameters for deployment.
        report: Dict[str, Any] = {
            "fusion_mode": "weighted",
            "norm_method": args.norm,
            "score_keys": score_keys,
            "branch_weights": branch_weights,
            "medians": medians,
            "norm_params": norm_params,
            "reward_weights": {"f1": args.reward_w_f1, "fp_score": args.reward_w_fp, "ap": args.reward_w_ap},
            "best_threshold": best_t,
            "tune_metrics_at_best": best_metrics,
            "sweep_head": sweep_rows[:5],
            "sweep_tail": sweep_rows[-5:],
        }
        artifact: Dict[str, Any] = {
            "fusion_mode": "weighted",
            "norm_method": args.norm,
            "score_keys": score_keys,
            "branch_weights": branch_weights,
            "medians": medians,
            "norm_params": norm_params,
            "best_threshold": best_t,
            "reward_weights": [args.reward_w_f1, args.reward_w_fp, args.reward_w_ap],
        }
        if args.stress:
            # Score the stress split at the tuned threshold; sliced + scalar metrics.
            stress_rows = load_jsonl(args.stress)
            ys = np.array([int(r["label"]) for r in stress_rows], dtype=np.int64)
            s_probs = fusion_prob_array(
                stress_rows, score_keys, medians, norm_params, branch_weights, args.norm
            )
            attach_fusion_scores(stress_rows, s_probs)
            report["stress"] = stress_report(
                stress_rows,
                score_key="fusion_score",
                threshold=best_t,
                slice_keys=("domain", "generator_family", "augmentation_type", "split"),
                reward_weights=reward_weights,
            )
            st_m = metrics_at_threshold(ys, s_probs, best_t, reward_weights=reward_weights)
            report["stress_scalar"] = st_m
            print("Stress metrics at best threshold:", json.dumps(st_m, indent=2))
        if args.test:
            # Locked test metrics at the already-chosen threshold (no re-tuning).
            test_rows = load_jsonl(args.test)
            ye = np.array([int(r["label"]) for r in test_rows], dtype=np.int64)
            e_probs = fusion_prob_array(
                test_rows, score_keys, medians, norm_params, branch_weights, args.norm
            )
            report["test"] = metrics_at_threshold(ye, e_probs, best_t, reward_weights=reward_weights)
    else:
        # Logistic mode: StandardScaler + LogisticRegression over branch scores
        # plus the two text-length features.
        feature_names = list(score_keys) + ["num_words", "num_sentences"]
        medians = compute_medians_legacy(fit_rows, tuple(score_keys))
        X_fit, y_fit = build_matrix(fit_rows, feature_names, medians)
        X_tune, y_tune = build_matrix(tune_rows, feature_names, medians)
        scaler = StandardScaler()
        Xf = scaler.fit_transform(X_fit)
        Xt = scaler.transform(X_tune)
        clf = LogisticRegression(max_iter=4000, random_state=42)
        clf.fit(Xf, y_fit)
        tune_probs = clf.predict_proba(Xt)[:, 1]
        best_t, best_metrics, sweep_rows = sweep_thresholds(
            y_tune, tune_probs, reward_weights=reward_weights
        )
        print("Fusion mode: logistic (Scaler + LR on scores + num_words + num_sentences)")
        print("Best threshold on tune split:", best_t)
        print(json.dumps(best_metrics, indent=2))
        attach_fusion_scores(tune_rows, tune_probs)
        report = {
            "fusion_mode": "logistic",
            "feature_names": feature_names,
            "medians": medians,
            "reward_weights": {"f1": args.reward_w_f1, "fp_score": args.reward_w_fp, "ap": args.reward_w_ap},
            "best_threshold": best_t,
            "tune_metrics_at_best": best_metrics,
            "sweep_head": sweep_rows[:5],
            "sweep_tail": sweep_rows[-5:],
        }
        # Scaler + LR parameters are serialized as plain lists/floats so the
        # artifact is self-contained for deployment (see --artifact help).
        artifact = {
            "fusion_mode": "logistic",
            "feature_names": feature_names,
            "medians": medians,
            "scaler_mean": scaler.mean_.tolist(),
            "scaler_scale": scaler.scale_.tolist(),
            "lr_coef": clf.coef_.ravel().tolist(),
            "lr_intercept": float(clf.intercept_[0]),
            "best_threshold": best_t,
            "reward_weights": [args.reward_w_f1, args.reward_w_fp, args.reward_w_ap],
        }
        if args.stress:
            # Same stress evaluation as weighted mode, using the fitted scaler + LR.
            stress_rows = load_jsonl(args.stress)
            Xs, ys = build_matrix(stress_rows, feature_names, medians)
            Xs = scaler.transform(Xs)
            s_probs = clf.predict_proba(Xs)[:, 1]
            attach_fusion_scores(stress_rows, s_probs)
            report["stress"] = stress_report(
                stress_rows,
                score_key="fusion_score",
                threshold=best_t,
                slice_keys=("domain", "generator_family", "augmentation_type", "split"),
                reward_weights=reward_weights,
            )
            st_m = metrics_at_threshold(ys, s_probs, best_t, reward_weights=reward_weights)
            report["stress_scalar"] = st_m
            print("Stress metrics at best threshold:", json.dumps(st_m, indent=2))
        if args.test:
            test_rows = load_jsonl(args.test)
            Xe, ye = build_matrix(test_rows, feature_names, medians)
            Xe = scaler.transform(Xe)
            e_probs = clf.predict_proba(Xe)[:, 1]
            report["test"] = metrics_at_threshold(ye, e_probs, best_t, reward_weights=reward_weights)
    # Persist artifact and report, creating parent directories as needed.
    os.makedirs(os.path.dirname(os.path.abspath(args.artifact)) or ".", exist_ok=True)
    with open(args.artifact, "w", encoding="utf-8") as f:
        json.dump(artifact, f, indent=2)
    print(f"Wrote artifact {args.artifact}")
    os.makedirs(os.path.dirname(os.path.abspath(args.report)) or ".", exist_ok=True)
    with open(args.report, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print(f"Wrote report {args.report}")
if __name__ == "__main__":
main()