FRET-FACS / models /lr_sequence_composition_baseline.py
neuwirtt
Initial release: FRET-FACS pipeline, weights, and datasets
6e4d123
Raw
History Blame Contribute Delete
28.8 kB
#!/usr/bin/env python3
"""
Composition-only baseline for FRET binary classification (Figure 3).
Logistic regression on simple sequence-derived features — no ESM embeddings:
- Hydrophobicity (mean Kyte-Doolittle GRAVY)
- Isoelectric point (pI, Biopython)
- Net charge at pH 7 (Biopython)
- Amino-acid group fractions (hydrophobic, basic, sulphur, amides, aromatic, hydroxy, acidic)
- Sequence complexity (Shannon entropy of amino-acid composition)
Uses the same stratified 80/20 train/test split and validation protocol as the
other model scripts (temperature scaling + F1-optimal threshold on validation;
test ROC-AUC reported on held-out set).
If test ROC-AUC >= 0.85, composition alone may explain much of the ESM signal.
"""
from __future__ import annotations
import argparse
import json
import math
import os
import secrets
from collections import Counter
from typing import Dict, List, Optional, Tuple
import esm
import joblib
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
accuracy_score,
average_precision_score,
classification_report,
confusion_matrix,
roc_auc_score,
roc_curve,
)
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
try:
from Bio.SeqUtils.IsoelectricPoint import IsoelectricPoint
from Bio.SeqUtils.ProtParam import ProteinAnalysis
except ImportError as exc:
raise ImportError(
"This script requires Biopython (pip install biopython). "
"It is used for isoelectric point and charge-at-pH features."
) from exc
# Kyte-Doolittle hydropathy (same table as structure_predictions_processing/esmfold_processing_DSSP.py)
# Per-residue hydropathy values; compute_gravy averages them over a sequence.
HYDROPATHY = {
    "A": 1.8,
    "C": 2.5,
    "D": -3.5,
    "E": -3.5,
    "F": 2.8,
    "G": -0.4,
    "H": -3.2,
    "I": 4.5,
    "K": -3.9,
    "L": 3.8,
    "M": 1.9,
    "N": -3.5,
    "P": -1.6,
    "Q": -3.5,
    "R": -4.5,
    "S": -0.8,
    "T": -0.7,
    "V": 4.2,
    "W": -0.9,
    "Y": -1.3,
}
# The 20 canonical residues (the HYDROPATHY keys). Any other symbol is dropped by
# standard_amino_acids before features are computed.
CANONICAL_AAS = set(HYDROPATHY)
# Residue classes. Each key becomes a per-sequence fraction feature
# (compute_aa_group_fractions). "frac_basic" / "frac_acidic" also drive the
# crude net-charge fallback in compute_net_charge_ph7.
AA_GROUPS = {
    "frac_hydrophobic": frozenset("LVIGAP"),
    "frac_basic": frozenset("RKH"),
    "frac_sulphur": frozenset("MC"),
    "frac_amides": frozenset("QN"),
    "frac_aromatic": frozenset("FWY"),
    "frac_hydroxy": frozenset("TS"),
    "frac_acidic": frozenset("DE"),
}
# Column order of the model's feature matrix X (selected from the per-variant
# feature table) and hence of the fitted LR coefficients.
FEATURE_NAMES = [
    "gravy_mean",
    "isoelectric_point",
    "net_charge_ph7",
    "frac_hydrophobic",
    "frac_basic",
    "frac_sulphur",
    "frac_amides",
    "frac_aromatic",
    "frac_hydroxy",
    "frac_acidic",
    "shannon_entropy_aa",
]
# Feature subsets: each is refit on its own in the --ablation run, and the
# mapping is used to tag permutation-importance rows with their group.
FEATURE_GROUPS = {
    "hydrophobicity": ["gravy_mean", "frac_hydrophobic"],
    "isoelectric_point": ["isoelectric_point"],
    "net_charge": ["net_charge_ph7"],
    "basic_residues": ["frac_basic"],
    "sulphur_residues": ["frac_sulphur"],
    "amides": ["frac_amides"],
    "aromatic_residues": ["frac_aromatic"],
    "hydroxy_residues": ["frac_hydroxy"],
    "acidic_residues": ["frac_acidic"],
    "sequence_complexity": ["shannon_entropy_aa"],
}
# Binary label vocabulary: 1 = highFRET, 0 = lowFRET / no FRET.
# Label cells are lower-cased and separator-normalised by parse_binary_label
# before being looked up here.
LABEL_TO_CLASS = {
    "highfret": 1,
    "lowfret": 0,
    "highff": 1,
    "lowff": 0,
    "noff": 0,
    "nofret": 0,
    "no_fret": 0,
    "high_ff": 1,
    "low_ff": 0,
}
# Candidate ID columns per --csv-schema, tried in order when --csv-id-col is not given.
ID_COLUMNS_BY_SCHEMA = {
    "esm": ("variant", "sequence_id", "id"),
    "nn": ("sequence_id", "variant", "id"),
}
# Label-column candidates, tried in order when --csv-label-col is not given.
# Prefer `label` (ESM/RF scripts) over `per_token_model_class` (nn_one_hot only), then `target`.
LABEL_COLUMN_CANDIDATES = ("label", "per_token_model_class", "target")
def parse_arguments(argv: Optional[List[str]] = None):
    """Parse command-line options for the composition LR baseline.

    Args:
        argv: Argument list to parse. ``None`` (the default) reads ``sys.argv``,
            so existing callers are unaffected; passing a list makes the parser
            usable from tests and notebooks.

    Returns:
        argparse.Namespace with upper-case attribute names (``FASTA_PATH``,
        ``CSV_PATH``, ``OUTPUT_DIR``, ``C``, ``SEED``, ...).

    Exits (via ``parser.error``, status 2) when ``--c`` is not strictly positive
    or ``--perm-n-repeats`` is below 1. Both values would otherwise only fail
    deep inside scikit-learn after all features had been computed.
    """
    parser = argparse.ArgumentParser(
        description="Logistic regression baseline on sequence composition features (no embeddings).",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--fasta", dest="FASTA_PATH", required=True, help="FASTA with variant sequences")
    parser.add_argument("--csv", dest="CSV_PATH", required=True, help="CSV with variant IDs and FRET labels")
    parser.add_argument("--output", dest="OUTPUT_DIR", required=True, help="Directory for outputs")
    parser.add_argument(
        "--csv-id-col",
        dest="CSV_ID_COL",
        default=None,
        help="CSV column for FASTA header IDs (default: variant, or sequence_id with --csv-schema nn)",
    )
    parser.add_argument(
        "--csv-label-col",
        dest="CSV_LABEL_COL",
        default=None,
        help="CSV label column (default: auto — prefers 'label', then per_token_model_class)",
    )
    parser.add_argument(
        "--csv-schema",
        dest="CSV_SCHEMA",
        choices=["esm", "nn"],
        default="esm",
        help="esm: prefer variant ID column; nn: prefer sequence_id. Label column always prefers 'label' if present.",
    )
    parser.add_argument(
        "--seed",
        dest="SEED",
        type=int,
        default=None,
        help="Random seed (default: generate and save to output_dir/random_seed.txt)",
    )
    parser.add_argument(
        "--c",
        dest="C",
        type=float,
        default=1.0,
        help="Inverse L2 regularization strength for LogisticRegression",
    )
    parser.add_argument(
        "--ablation",
        dest="ABLATION",
        action="store_true",
        help="Also fit logistic models on each feature group alone and report test AUC",
    )
    parser.add_argument(
        "--save-diagnostics",
        dest="SAVE_DIAGNOSTICS",
        action="store_true",
        # Coefficient CSV is always written; ablation CSV is controlled by --ablation.
        help="Save ROC curve CSV/PNG and the per-variant composition feature table",
    )
    parser.add_argument(
        "--save-features",
        dest="SAVE_FEATURES",
        action="store_true",
        help="Write per-variant composition feature matrix CSV",
    )
    parser.add_argument(
        "--perm-n-repeats",
        dest="PERM_N_REPEATS",
        type=int,
        default=30,
        help="Number of shuffle repeats for permutation importance (evaluated on test set)",
    )
    args = parser.parse_args(argv)
    # `not x > 0` also rejects NaN, which argparse's float() happily accepts.
    if not args.C > 0:
        parser.error("--c must be a positive number")
    if args.PERM_N_REPEATS < 1:
        parser.error("--perm-n-repeats must be at least 1")
    return args
def _pick_column(name: str, candidates: Tuple[str, ...], available: List[str]) -> Optional[str]:
if name in available:
return name
for col in candidates:
if col in available:
return col
return None
def _normalize_label_key(text) -> str:
    """Lower-case ``text``, strip it, and remove every space, dash and underscore."""
    return str(text).strip().lower().replace(" ", "").replace("-", "").replace("_", "")


def parse_binary_label(value, *, label_map: Optional[Dict[str, int]] = None) -> Optional[int]:
    """Map a CSV label cell to 0 (lowFRET) or 1 (highFRET); unknown values return None.

    Accepted inputs:
      * booleans -> ``int(value)``;
      * integers / floats equal to 0 or 1, and the strings ``"0"`` / ``"1"``;
      * label strings found in ``label_map`` (default ``LABEL_TO_CLASS``).
        Matching is case-insensitive and ignores surrounding whitespace and any
        spaces, dashes or underscores. So "highFRET", "High-FRET", "high fret"
        and "high_fret" all resolve to the same entry, regardless of which
        separator style the table key itself uses.

    Missing values (``None`` / NaN) and anything unrecognised return ``None``.

    Args:
        value: Raw cell value from the CSV label column.
        label_map: Optional replacement vocabulary (label string -> 0/1).
    """
    if value is None:
        return None
    try:
        if pd.isna(value):
            return None
    except (TypeError, ValueError):
        # Array-like cells make the truth test ambiguous; treat them as
        # "not missing" and let the string matching below reject them.
        pass
    # bool must be tested before int: bool is a subclass of int.
    if isinstance(value, (bool, np.bool_)):
        return int(value)
    if isinstance(value, (int, np.integer)):
        return int(value) if int(value) in (0, 1) else None
    if isinstance(value, (float, np.floating)):
        fv = float(value)
        return int(fv) if fv in (0.0, 1.0) else None
    key = _normalize_label_key(value)
    # Numeric labels stored as text (e.g. an object column mixing "1" and "highFRET").
    if key in ("0", "1"):
        return int(key)
    table = LABEL_TO_CLASS if label_map is None else label_map
    normalized = {_normalize_label_key(k): v for k, v in table.items()}
    return normalized.get(key)
def resolve_csv_columns(args, csv_columns: List[str]) -> Tuple[str, str]:
    """Resolve the CSV ID and label columns.

    Each of ``--csv-id-col`` / ``--csv-label-col`` is honoured independently
    when supplied. Whichever is not supplied is auto-detected: the ID column
    from ``ID_COLUMNS_BY_SCHEMA[--csv-schema]``, and the label column from
    ``LABEL_COLUMN_CANDIDATES`` (which prefers ``label``).

    Args:
        args: Parsed CLI namespace (reads CSV_ID_COL, CSV_LABEL_COL, CSV_SCHEMA).
        csv_columns: Header names of the CSV.

    Returns:
        ``(id_col, label_col)``.

    Raises:
        ValueError: if a column cannot be auto-detected, or an explicitly
            requested column is not in the CSV header.
    """
    available = list(csv_columns)
    id_candidates = ID_COLUMNS_BY_SCHEMA.get(args.CSV_SCHEMA, ID_COLUMNS_BY_SCHEMA["esm"])
    # Resolve each column on its own so that a lone override (e.g. only
    # --csv-id-col) is not discarded in favour of auto-detection.
    id_col = args.CSV_ID_COL or _pick_column("", id_candidates, available)
    label_col = args.CSV_LABEL_COL or _pick_column("", LABEL_COLUMN_CANDIDATES, available)
    if not id_col or id_col not in available:
        tried = f"'{args.CSV_ID_COL}' (from --csv-id-col)" if args.CSV_ID_COL else f"{id_candidates}"
        raise ValueError(
            f"Could not find ID column in CSV. Tried {tried}; "
            f"columns: {available}. Use --csv-id-col."
        )
    if not label_col or label_col not in available:
        tried = (
            f"'{args.CSV_LABEL_COL}' (from --csv-label-col)"
            if args.CSV_LABEL_COL
            else f"{LABEL_COLUMN_CANDIDATES}"
        )
        raise ValueError(
            f"Could not find label column in CSV. Tried {tried}; "
            f"columns: {available}. Use --csv-label-col label"
        )
    return id_col, label_col
def standard_amino_acids(sequence: str) -> str:
    """Upper-case ``sequence`` and keep only the 20 canonical residue letters (CANONICAL_AAS)."""
    upper = str(sequence).upper()
    return "".join(filter(CANONICAL_AAS.__contains__, upper))
def compute_gravy(sequence: str) -> float:
    """Mean Kyte-Doolittle hydropathy (GRAVY) over canonical residues.

    Returns NaN when ``sequence`` contains no canonical residue.
    """
    clean = standard_amino_acids(sequence)
    if not clean:
        return float("nan")
    return float(np.mean([HYDROPATHY[residue] for residue in clean]))
def compute_isoelectric_point(sequence: str) -> float:
    """Biopython isoelectric point of the canonical-residue sequence.

    Returns NaN for an empty cleaned sequence, or if Biopython raises. This is a
    deliberate best-effort: NaN rows are dropped later by the caller.
    """
    clean = standard_amino_acids(sequence)
    if clean:
        try:
            return float(IsoelectricPoint(clean).pi())
        except Exception:
            pass  # fall through to NaN
    return float("nan")
def compute_net_charge_ph7(sequence: str) -> float:
    """Net charge at pH 7.0 via Biopython ``ProteinAnalysis.charge_at_pH``.

    Returns NaN for an empty cleaned sequence. If Biopython raises, it falls
    back to a crude residue count: (#R + #K + #H) - (#D + #E).
    """
    clean = standard_amino_acids(sequence)
    if not clean:
        return float("nan")
    try:
        charge = float(ProteinAnalysis(clean).charge_at_pH(7.0))
    except Exception:
        tally = Counter(clean)
        positives = sum(tally[aa] for aa in AA_GROUPS["frac_basic"])
        negatives = sum(tally[aa] for aa in AA_GROUPS["frac_acidic"])
        return float(positives - negatives)
    return charge
def compute_aa_group_fractions(sequence: str) -> Dict[str, float]:
    """Fraction of canonical residues that fall in each AA_GROUPS class.

    Every fraction is NaN when the cleaned sequence is empty.
    """
    clean = standard_amino_acids(sequence)
    total = len(clean)
    if not total:
        return {group: float("nan") for group in AA_GROUPS}
    tally = Counter(clean)
    fractions = {}
    for group, members in AA_GROUPS.items():
        fractions[group] = sum(tally[aa] for aa in members) / total
    return fractions
def compute_complexity_features(sequence: str) -> Dict[str, float]:
    """Shannon entropy (bits) of the canonical amino-acid composition.

    Returns ``{"shannon_entropy_aa": NaN}`` for an empty cleaned sequence.
    """
    clean = standard_amino_acids(sequence)
    total = len(clean)
    if total == 0:
        return {"shannon_entropy_aa": float("nan")}
    bits = 0.0
    for share in (occurrences / total for occurrences in Counter(clean).values()):
        bits -= share * math.log2(share)
    return {"shannon_entropy_aa": bits}
def compute_composition_features(sequence: str) -> Dict[str, float]:
    """Assemble every composition feature for one sequence into a flat dict.

    Keys: gravy_mean, isoelectric_point and net_charge_ph7, then the AA_GROUPS
    fractions, then shannon_entropy_aa.
    """
    return {
        "gravy_mean": compute_gravy(sequence),
        "isoelectric_point": compute_isoelectric_point(sequence),
        "net_charge_ph7": compute_net_charge_ph7(sequence),
        **compute_aa_group_fractions(sequence),
        **compute_complexity_features(sequence),
    }
def load_labeled_sequences(
    fasta_path: str, csv_path: str, id_col: str, label_col: str
) -> Tuple[List[str], List[str], np.ndarray, pd.DataFrame, np.ndarray]:
    """Join FASTA records to CSV labels and compute composition features.

    A FASTA record is kept when its header (as yielded by ``esm.data.read_fasta``)
    equals a value of ``id_col`` compared as strings, and that row's
    ``label_col`` value is recognised by ``parse_binary_label``. For IDs that
    appear more than once in the CSV, the first row is used.

    Records without a CSV match are skipped silently. Records with
    unrecognised labels are skipped and counted in a printed warning.

    Returns:
        ``(variant_ids, sequences, y, feature_df, X)``:
        - ``y`` is the int label vector.
        - ``feature_df`` holds ``variant``, ``label``, ``sequence_length`` and
          every FEATURE_NAMES column.
        - ``X`` is the float matrix of the FEATURE_NAMES columns, row-aligned
          with ``y``.

    Raises:
        ValueError: if ``id_col`` or ``label_col`` is missing from the CSV.
        RuntimeError: if no FASTA record could be matched to a usable label.
    """
    df = pd.read_csv(csv_path, header=0, index_col=None)
    if id_col not in df.columns:
        raise ValueError(f"CSV missing ID column '{id_col}'. Columns: {list(df.columns)}")
    if label_col not in df.columns:
        raise ValueError(f"CSV missing label column '{label_col}'. Columns: {list(df.columns)}")
    print(f"Label column '{label_col}' value counts in CSV (all rows):")
    print(df[label_col].value_counts(dropna=False).head(20).to_string())

    # Index the CSV once (ID string -> raw label, first occurrence wins). This
    # avoids re-filtering the whole DataFrame per FASTA record, which would be
    # O(n_fasta * n_csv).
    label_by_id: Dict[str, object] = {}
    for raw_id, raw_label in zip(df[id_col].astype(str), df[label_col]):
        label_by_id.setdefault(raw_id, raw_label)

    variant_ids: List[str] = []
    sequences: List[str] = []
    labels: List[int] = []
    feature_rows: List[Dict[str, float]] = []
    skipped_unknown_label: Counter = Counter()
    for header, seq in esm.data.read_fasta(fasta_path):
        variant_id = header
        key = str(variant_id)
        if key not in label_by_id:
            continue
        raw_label = label_by_id[key]
        label = parse_binary_label(raw_label)
        if label is None:
            skipped_unknown_label[str(raw_label)] += 1
            continue
        variant_ids.append(variant_id)
        sequences.append(seq)
        labels.append(label)
        feature_rows.append(compute_composition_features(seq))

    if skipped_unknown_label:
        print(
            f"Warning: skipped {sum(skipped_unknown_label.values())} FASTA rows with "
            f"unrecognized labels in '{label_col}': {dict(skipped_unknown_label)}"
        )
    if not variant_ids:
        raise RuntimeError(
            f"No labeled variants matched between FASTA and CSV using id='{id_col}', "
            f"label='{label_col}'. Expected highFRET/lowFRET (or 0/1) in the label column."
        )

    feature_df = pd.DataFrame(feature_rows)
    feature_df.insert(0, "variant", variant_ids)
    feature_df.insert(1, "label", labels)
    feature_df.insert(2, "sequence_length", [len(standard_amino_acids(s)) for s in sequences])
    X = feature_df[FEATURE_NAMES].to_numpy(dtype=float)
    y = np.array(labels, dtype=int)
    return variant_ids, sequences, y, feature_df, X
def apply_temperature_scaling(y_prob: np.ndarray, temperature: float) -> np.ndarray:
    """Rescale probabilities by dividing their logits by ``temperature``.

    Inputs are first clipped to [1e-15, 1 - 1e-15] so the logit stays finite.
    ``temperature < 1`` sharpens and ``> 1`` softens; ``temperature == 1``
    leaves the (clipped) probabilities unchanged.
    """
    tiny = 1e-15
    bounded = np.clip(y_prob, tiny, 1 - tiny)
    odds = bounded / (1 - bounded)
    scaled_logits = np.log(odds) / temperature
    calibrated = 1.0 / (1.0 + np.exp(-scaled_logits))
    return np.clip(calibrated, 0.0, 1.0)
def fit_temperature(y_true: np.ndarray, y_prob: np.ndarray, n_trials: int = 100) -> float:
    """Grid-search the temperature that minimises ECE of the rescaled probabilities.

    The candidates are ``n_trials`` log-spaced values in [0.01, 10]. The first
    candidate with the strictly lowest ECE wins. 1.0 is returned when no
    candidate is evaluated or none improves on +inf.
    """
    chosen_temperature = 1.0
    lowest_ece = float("inf")
    for candidate in np.logspace(-2, 1, n_trials):
        candidate_ece = _compute_ece(y_true, apply_temperature_scaling(y_prob, candidate))
        if candidate_ece < lowest_ece:
            chosen_temperature, lowest_ece = candidate, candidate_ece
    return chosen_temperature
def _compute_ece(y_true: np.ndarray, y_prob: np.ndarray, n_bins: int = 10) -> float:
bin_boundaries = np.linspace(0, 1, n_bins + 1)
ece = 0.0
for i in range(n_bins):
in_bin = (y_prob > bin_boundaries[i]) & (y_prob <= bin_boundaries[i + 1])
prop = np.mean(in_bin)
if prop > 0:
ece += prop * abs(np.mean(y_true[in_bin]) - np.mean(y_prob[in_bin]))
return float(ece)
def optimal_f1_threshold(y_true: np.ndarray, y_prob: np.ndarray, n_thresholds: int = 101) -> float:
    """Return the decision threshold on a uniform [0, 1] grid that maximises F1.

    A sample is predicted positive when ``y_prob >= threshold``. Ties go to the
    smallest threshold. 0.5 is returned when the grid is empty.
    """
    grid = np.linspace(0, 1, n_thresholds)
    if len(grid) == 0:
        return 0.5
    actual_pos = y_true == 1
    actual_neg = y_true == 0
    f1_scores = []
    for cut in grid:
        called_pos = (y_prob >= cut).astype(int) == 1
        tp = np.sum(called_pos & actual_pos)
        fp = np.sum(called_pos & actual_neg)
        fn = np.sum(~called_pos & actual_pos)
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        if precision + recall > 0:
            f1_scores.append(2 * precision * recall / (precision + recall))
        else:
            f1_scores.append(0.0)
    # Scan for the first strict maximum (same tie-breaking as a running "best").
    winner, top = 0, f1_scores[0]
    for idx, score in enumerate(f1_scores):
        if score > top:
            winner, top = idx, score
    return float(grid[winner])
def fit_lr_pipeline(
    X_train: np.ndarray,
    y_train: np.ndarray,
    C: float,
    seed: int,
) -> Tuple[LogisticRegression, StandardScaler]:
    """Standardise the training features, then fit an L2 logistic regression on them.

    Returns the fitted ``(classifier, scaler)``. The scaler must be applied to any
    later input before calling the classifier.
    """
    scaler = StandardScaler()
    standardized = scaler.fit_transform(X_train)
    model = LogisticRegression(
        C=C,
        penalty="l2",
        solver="lbfgs",
        max_iter=2000,
        random_state=seed,
    )
    model.fit(standardized, y_train)
    return model, scaler
def predict_proba(
    clf: LogisticRegression, scaler: StandardScaler, X: np.ndarray
) -> np.ndarray:
    """Scale ``X`` with the fitted scaler and return the positive-class (column 1) probability."""
    scaled = scaler.transform(X)
    class_probs = clf.predict_proba(scaled)
    return class_probs[:, 1]
def compute_permutation_importance_table(
    clf: LogisticRegression,
    scaler: StandardScaler,
    X: np.ndarray,
    y: np.ndarray,
    feature_names: List[str],
    *,
    seed: int,
    n_repeats: int = 30,
    scoring: str = "roc_auc",
    eval_split: str = "test",
) -> pd.DataFrame:
    """
    Permutation importance on scaled features (score drop when each column is shuffled).

    Returns one row per feature, sorted by ``importance_mean`` (descending), with
    a 1-based ``rank`` column first and a ``feature_group`` column taken from
    FEATURE_GROUPS. Higher importance_mean means a stronger contribution to
    discrimination.
    """
    scaled = scaler.transform(X)
    outcome = permutation_importance(
        clf,
        scaled,
        y,
        scoring=scoring,
        n_repeats=n_repeats,
        random_state=seed,
        n_jobs=-1,
    )
    group_of: Dict[str, str] = {}
    for group, members in FEATURE_GROUPS.items():
        for member in members:
            group_of[member] = group
    table = pd.DataFrame(
        {
            "feature": feature_names,
            "importance_mean": outcome.importances_mean,
            "importance_std": outcome.importances_std,
            "eval_split": eval_split,
            "scoring": scoring,
            "n_repeats": n_repeats,
        }
    )
    table = table.sort_values("importance_mean", ascending=False).reset_index(drop=True)
    table.insert(0, "rank", np.arange(1, len(table) + 1))
    table["feature_group"] = table["feature"].map(group_of)
    return table
def matthews_corrcoef_safe(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Matthews correlation from the 2x2 confusion matrix.

    Returns NaN instead of warning when the coefficient is undefined, i.e. when
    any row or column of the matrix sums to zero.
    """
    undefined = float("nan")
    matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
    if matrix.shape != (2, 2):
        return undefined
    tn, fp, fn, tp = (int(cell) for cell in matrix.ravel())
    marginal_product = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    if marginal_product <= 0:
        return undefined
    coefficient = (tp * tn - fp * fn) / np.sqrt(float(marginal_product))
    if not np.isfinite(coefficient):
        return undefined
    return float(coefficient)
def evaluate_split(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    threshold: float,
) -> Dict[str, float]:
    """Threshold-free and thresholded binary metrics for one split.

    ``roc_auc`` is NaN when undefined (e.g. a single class). Count-based ratios
    fall back to 0.0 when their denominator is zero.
    """
    y_pred = (y_prob >= threshold).astype(int)
    try:
        auc = float(roc_auc_score(y_true, y_prob))
    except ValueError:
        auc = float("nan")
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel() if cm.size == 4 else (0, 0, 0, 0)

    def _ratio(numerator, denominator) -> float:
        return float(numerator / denominator) if denominator > 0 else 0.0

    return {
        "roc_auc": auc,
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "average_precision": float(average_precision_score(y_true, y_prob)),
        "mcc": matthews_corrcoef_safe(y_true, y_pred),
        "sensitivity": _ratio(tp, tp + fn),
        "specificity": _ratio(tn, tn + fp),
        "precision": _ratio(tp, tp + fp),
        "f1": _ratio(2 * tp, 2 * tp + fp + fn),
    }
def train_eval_composition_model(
    X: np.ndarray,
    y: np.ndarray,
    feature_names: List[str],
    C: float,
    seed: int,
) -> Dict:
    """Split, fit, calibrate and evaluate the composition LR.

    The protocol is:
      1. Stratified 80/20 train/test split.
      2. A further stratified 20 % of train is held out for validation.
      3. The LR is fitted on the remaining 64 %.
      4. A temperature and an F1-optimal threshold are chosen on validation.
      5. The test metrics use the calibrated probabilities.

    Returns a dict containing the fitted classifier/scaler, calibration
    parameters, val/test metrics, per-feature coefficients, and the test arrays.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=0.8, random_state=seed, stratify=y
    )
    X_fit, X_val, y_fit, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=seed, stratify=y_train
    )
    clf, scaler = fit_lr_pipeline(X_fit, y_fit, C=C, seed=seed)

    raw_val = predict_proba(clf, scaler, X_val)
    temperature = fit_temperature(y_val, raw_val)
    calibrated_val = apply_temperature_scaling(raw_val, temperature)
    threshold = optimal_f1_threshold(y_val, calibrated_val)

    calibrated_test = apply_temperature_scaling(predict_proba(clf, scaler, X_test), temperature)

    return {
        "clf": clf,
        "scaler": scaler,
        "feature_names": feature_names,
        "temperature": temperature,
        "threshold": threshold,
        "test_metrics": evaluate_split(y_test, calibrated_test, threshold),
        "val_metrics": evaluate_split(y_val, calibrated_val, threshold),
        "coefficients": dict(zip(feature_names, map(float, clf.coef_[0]))),
        "y_test": y_test,
        "y_val": y_val,
        "y_test_proba": calibrated_test,
        "X_test": X_test,
    }
def save_roc_plot(y_true: np.ndarray, y_prob: np.ndarray, auc: float, path: str, title: str):
    """Draw the ROC curve with a chance diagonal and save it to ``path`` at 150 dpi."""
    curve_fpr, curve_tpr, _ = roc_curve(y_true, y_prob)
    figure, axes = plt.subplots(figsize=(5, 5))
    axes.plot(curve_fpr, curve_tpr, label=f"ROC (AUC = {auc:.3f})")
    axes.plot([0, 1], [0, 1], "k--")
    axes.set(xlabel="False positive rate", ylabel="True positive rate", title=title)
    axes.legend()
    axes.set_xlim(0, 1)
    axes.set_ylim(0, 1)
    figure.tight_layout()
    figure.savefig(path, dpi=150)
    plt.close(figure)
def _resolve_seed(args) -> int:
    """Return --seed if given; otherwise draw a random seed and record it in random_seed.txt."""
    if args.SEED is not None:
        seed = args.SEED
        print(f"Using provided random seed: {seed}")
        return seed
    seed = secrets.randbelow(2**32)
    with open(os.path.join(args.OUTPUT_DIR, "random_seed.txt"), "w") as f:
        f.write(f"{seed}\n")
    print(f"Generated random seed: {seed} (saved to {args.OUTPUT_DIR}/random_seed.txt)")
    return seed


def _run_ablation(X: np.ndarray, y: np.ndarray, C: float, seed: int, output_dir: str) -> None:
    """Fit one LR per FEATURE_GROUPS entry on that group's columns alone; print and save the AUCs."""
    print("\n" + "-" * 80)
    print("Ablation: test ROC-AUC by feature group")
    print("-" * 80)
    ablation_rows = []
    for group_name, cols in FEATURE_GROUPS.items():
        col_idx = [FEATURE_NAMES.index(c) for c in cols]
        # Same seed and labels as the full model, so the splits should match.
        sub_result = train_eval_composition_model(X[:, col_idx], y, cols, C=C, seed=seed)
        auc_sub = sub_result["test_metrics"]["roc_auc"]
        print(f" {group_name}: AUC = {auc_sub:.4f} ({', '.join(cols)})")
        ablation_rows.append(
            {
                "feature_group": group_name,
                "features": ";".join(cols),
                "test_roc_auc": auc_sub,
                "val_roc_auc": sub_result["val_metrics"]["roc_auc"],
            }
        )
    pd.DataFrame(ablation_rows).to_csv(
        os.path.join(output_dir, "composition_baseline_ablation.csv"), index=False
    )


def _print_artifact_list(args) -> None:
    """Print the names of the files this run wrote to the output directory."""
    print(f"\nArtifacts written to '{args.OUTPUT_DIR}/':")
    print(" - composition_baseline_summary.csv (primary AUC + interpretation flag)")
    print(
        " - composition_lr_model.joblib (clf + scaler + temperature/threshold "
        "for evaluate_lr_sequence_composition.py)"
    )
    print(" - composition_lr_auc_score.csv")
    print(" - composition_lr_coefficients.csv")
    print(" - composition_lr_permutation_importance.csv")
    print(" - composition_lr_validation_diagnostics_summary.csv")
    print(" - composition_lr_validation_temperature.csv")
    if args.ABLATION:
        print(" - composition_baseline_ablation.csv")
    if args.SAVE_FEATURES or args.SAVE_DIAGNOSTICS:
        print(" - composition_features_per_variant.csv")
    if args.SAVE_DIAGNOSTICS:
        print(" - composition_lr_roc_curve.csv / .png")


def main():
    """CLI entry point: fit the composition-only LR baseline and write all artifacts."""
    args = parse_arguments()
    os.makedirs(args.OUTPUT_DIR, exist_ok=True)
    seed = _resolve_seed(args)
    np.random.seed(seed)

    # Only the header is needed to resolve column names.
    csv_preview = pd.read_csv(args.CSV_PATH, nrows=5)
    id_col, label_col = resolve_csv_columns(args, list(csv_preview.columns))
    print("=" * 80)
    print("Composition-only logistic regression baseline (no embeddings)")
    print("=" * 80)
    print(f"FASTA: {args.FASTA_PATH}")
    print(f"CSV: {args.CSV_PATH} (id={id_col}, label={label_col})")
    print(f"Output: {args.OUTPUT_DIR}")
    print(f"Features ({len(FEATURE_NAMES)}): {', '.join(FEATURE_NAMES)}")

    variant_ids, _sequences, y, feature_df, X = load_labeled_sequences(
        args.FASTA_PATH, args.CSV_PATH, id_col, label_col
    )
    print(f"\nLoaded {len(variant_ids)} variants")

    # Drop rows with undefined features (e.g. no canonical residues) BEFORE the
    # class-balance check, so the check guards the data that is actually fitted.
    valid = ~np.isnan(X).any(axis=1)
    n_nan = int((~valid).sum())
    if n_nan:
        print(f"Warning: dropping {n_nan} variants with NaN composition features")
        X = X[valid]
        y = y[valid]
        feature_df = feature_df.loc[valid].reset_index(drop=True)
        variant_ids = [v for v, keep in zip(variant_ids, valid) if keep]

    class_counts = Counter(y.tolist())
    print(f"Class distribution (0=lowFRET, 1=highFRET): {class_counts}")
    if len(class_counts) < 2:
        raise ValueError(
            f"Need both classes for logistic regression, but only found: {class_counts}. "
            f"Check that --csv-label-col points to the column with highFRET and lowFRET "
            f"(currently using '{label_col}'). For FF_all_rounds CSV use "
            f"--csv-label-col label --csv-id-col variant (default with --csv-schema esm)."
        )

    if args.SAVE_FEATURES or args.SAVE_DIAGNOSTICS:
        feature_df.to_csv(os.path.join(args.OUTPUT_DIR, "composition_features_per_variant.csv"), index=False)

    print("\nTraining full composition logistic regression...")
    result = train_eval_composition_model(X, y, FEATURE_NAMES, C=args.C, seed=seed)
    test_auc = result["test_metrics"]["roc_auc"]
    val_auc = result["val_metrics"]["roc_auc"]

    print("\n" + "=" * 80)
    print("PRIMARY RESULT — held-out test set")
    print("=" * 80)
    print(f"Test ROC-AUC (composition baseline): {test_auc:.4f}")
    print(f"Validation ROC-AUC: {val_auc:.4f}")
    if test_auc >= 0.85:
        print(
            "\n*** Test AUC >= 0.85: strong composition-only signal. "
            "ESM-2 gains may largely reflect amino-acid composition / biophysical "
            "properties rather than learned sequence context beyond these features. ***"
        )
    else:
        print(
            "\nTest AUC < 0.85: composition features alone do not match typical "
            "strong ESM performance; embeddings may carry additional predictive signal."
        )
    print("\nTest metrics:")
    for k, v in result["test_metrics"].items():
        print(f" {k}: {v:.4f}")
    print("\nStandardized logistic coefficients (positive class = highFRET):")
    for name, coef in sorted(result["coefficients"].items(), key=lambda x: -abs(x[1])):
        print(f" {name}: {coef:+.4f}")

    print(
        f"\nPermutation importance (test set, scoring=roc_auc, "
        f"n_repeats={args.PERM_N_REPEATS})..."
    )
    perm_df = compute_permutation_importance_table(
        result["clf"],
        result["scaler"],
        result["X_test"],
        result["y_test"],
        FEATURE_NAMES,
        seed=seed,
        n_repeats=args.PERM_N_REPEATS,
        eval_split="test",
    )
    perm_path = os.path.join(args.OUTPUT_DIR, "composition_lr_permutation_importance.csv")
    perm_df.to_csv(perm_path, index=False)
    print("Top features by mean ROC-AUC decrease when permuted:")
    for _, row in perm_df.head(len(FEATURE_NAMES)).iterrows():
        print(
            f" {int(row['rank'])}. {row['feature']}: "
            f"{row['importance_mean']:.4f} ± {row['importance_std']:.4f}"
        )

    print("\nClassification report (test):")
    y_pred = (result["y_test_proba"] >= result["threshold"]).astype(int)
    print(classification_report(result["y_test"], y_pred, zero_division=0))

    summary = pd.DataFrame(
        [
            {
                "model": "composition_lr_full",
                "n_features": len(FEATURE_NAMES),
                "features": ";".join(FEATURE_NAMES),
                "test_roc_auc": test_auc,
                "val_roc_auc": val_auc,
                "test_accuracy": result["test_metrics"]["accuracy"],
                "test_f1": result["test_metrics"]["f1"],
                "test_mcc": result["test_metrics"]["mcc"],
                "temperature": result["temperature"],
                "optimal_threshold_f1": result["threshold"],
                "C": args.C,
                "n_variants": len(y),
                "composition_auc_ge_0_85": test_auc >= 0.85,
            }
        ]
    )
    summary.to_csv(os.path.join(args.OUTPUT_DIR, "composition_baseline_summary.csv"), index=False)
    pd.DataFrame({"auc_score": [test_auc]}).to_csv(
        os.path.join(args.OUTPUT_DIR, "composition_lr_auc_score.csv"), index=False
    )
    pd.DataFrame([result["coefficients"]]).to_csv(
        os.path.join(args.OUTPUT_DIR, "composition_lr_coefficients.csv"), index=False
    )
    pd.DataFrame(
        {
            "temperature": [result["temperature"]],
            "optimal_threshold_f1": [result["threshold"]],
        }
    ).to_csv(
        os.path.join(args.OUTPUT_DIR, "composition_lr_validation_diagnostics_summary.csv"),
        index=False,
    )
    pd.DataFrame({"temperature": [result["temperature"]]}).to_csv(
        os.path.join(args.OUTPUT_DIR, "composition_lr_validation_temperature.csv"),
        index=False,
    )

    model_path = os.path.join(args.OUTPUT_DIR, "composition_lr_model.joblib")
    # Bundle the validation-fitted calibration with the model so inference can
    # reproduce calibrated probabilities and hard calls from a single file.
    joblib.dump(
        {
            "clf": result["clf"],
            "scaler": result["scaler"],
            "feature_names": FEATURE_NAMES,
            "C": args.C,
            "temperature": float(result["temperature"]),
            "threshold": float(result["threshold"]),
        },
        model_path,
    )
    print(f"Inference model saved: {model_path}")

    if args.SAVE_DIAGNOSTICS:
        fpr, tpr, thresholds = roc_curve(result["y_test"], result["y_test_proba"])
        pd.DataFrame({"fpr": fpr, "tpr": tpr, "thresholds": thresholds}).to_csv(
            os.path.join(args.OUTPUT_DIR, "composition_lr_roc_curve.csv"), index=False
        )
        save_roc_plot(
            result["y_test"],
            result["y_test_proba"],
            test_auc,
            os.path.join(args.OUTPUT_DIR, "composition_lr_roc_curve.png"),
            "Composition LR (test)",
        )

    if args.ABLATION:
        _run_ablation(X, y, args.C, seed, args.OUTPUT_DIR)

    meta = {
        "model": "sklearn LogisticRegression (L2)",
        "features": FEATURE_NAMES,
        "feature_groups": FEATURE_GROUPS,
        "split": "80% train / 20% test; 20% of train for validation calibration",
        # NaN is not valid JSON; an undefined AUC (single-class test set) is stored as null.
        "test_roc_auc": float(test_auc) if np.isfinite(test_auc) else None,
        "interpretation_threshold": 0.85,
        "csv_id_col": id_col,
        "csv_label_col": label_col,
    }
    with open(os.path.join(args.OUTPUT_DIR, "composition_baseline_metadata.json"), "w") as f:
        json.dump(meta, f, indent=2)

    _print_artifact_list(args)


if __name__ == "__main__":
    main()