| |
| """ |
| Composition-only baseline for FRET binary classification (Figure 3). |
| |
| Logistic regression on simple sequence-derived features — no ESM embeddings: |
| - Hydrophobicity (mean Kyte-Doolittle GRAVY) |
| - Isoelectric point (pI, Biopython) |
| - Net charge at pH 7 (Biopython) |
| - Amino-acid group fractions (hydrophobic, basic, sulphur, amides, aromatic, hydroxy, acidic) |
| - Sequence complexity (Shannon entropy of amino-acid composition) |
| |
| Uses the same stratified 80/20 train/test split and validation protocol as the |
| other model scripts (temperature scaling + F1-optimal threshold on validation; |
| test ROC-AUC reported on held-out set). |
| |
| If test ROC-AUC >= 0.85, composition alone may explain much of the ESM signal. |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import os |
| import secrets |
| from collections import Counter |
| from typing import Dict, List, Optional, Tuple |
|
|
| import esm |
| import joblib |
| import matplotlib |
|
|
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import pandas as pd |
| from sklearn.linear_model import LogisticRegression |
| from sklearn.metrics import ( |
| accuracy_score, |
| average_precision_score, |
| classification_report, |
| confusion_matrix, |
| roc_auc_score, |
| roc_curve, |
| ) |
| from sklearn.inspection import permutation_importance |
| from sklearn.model_selection import train_test_split |
| from sklearn.preprocessing import StandardScaler |
|
|
| try: |
| from Bio.SeqUtils.IsoelectricPoint import IsoelectricPoint |
| from Bio.SeqUtils.ProtParam import ProteinAnalysis |
| except ImportError as exc: |
| raise ImportError( |
| "This script requires Biopython (pip install biopython). " |
| "It is used for isoelectric point and charge-at-pH features." |
| ) from exc |
|
|
| |
# Kyte-Doolittle hydropathy index, keyed by one-letter amino-acid code.
# Larger values are more hydrophobic; the mean over a sequence is its GRAVY score.
HYDROPATHY = {
    "A": 1.8, "C": 2.5, "D": -3.5, "E": -3.5, "F": 2.8,
    "G": -0.4, "H": -3.2, "I": 4.5, "K": -3.9, "L": 3.8,
    "M": 1.9, "N": -3.5, "P": -1.6, "Q": -3.5, "R": -4.5,
    "S": -0.8, "T": -0.7, "V": 4.2, "W": -0.9, "Y": -1.3,
}
# The 20 residue codes that have a hydropathy entry; anything else is
# filtered out of a sequence before features are computed.
CANONICAL_AAS = set(HYDROPATHY)
|
|
# Residue classes (one-letter codes); the fraction of a sequence falling in
# each class becomes one model feature named after the dictionary key.
AA_GROUPS = {
    group_name: frozenset(letters)
    for group_name, letters in (
        ("frac_hydrophobic", "LVIGAP"),
        ("frac_basic", "RKH"),
        ("frac_sulphur", "MC"),
        ("frac_amides", "QN"),
        ("frac_aromatic", "FWY"),
        ("frac_hydroxy", "TS"),
        ("frac_acidic", "DE"),
    )
}

# Column order of the feature matrix: three global descriptors, then the
# group fractions in AA_GROUPS order, then the composition entropy.
FEATURE_NAMES = [
    "gravy_mean",
    "isoelectric_point",
    "net_charge_ph7",
    *AA_GROUPS,
    "shannon_entropy_aa",
]

# Feature subsets fitted on their own during the optional ablation run, and
# used to tag each feature in the permutation-importance table.
FEATURE_GROUPS = {
    "hydrophobicity": ["gravy_mean", "frac_hydrophobic"],
    "isoelectric_point": ["isoelectric_point"],
    "net_charge": ["net_charge_ph7"],
    "basic_residues": ["frac_basic"],
    "sulphur_residues": ["frac_sulphur"],
    "amides": ["frac_amides"],
    "aromatic_residues": ["frac_aromatic"],
    "hydroxy_residues": ["frac_hydroxy"],
    "acidic_residues": ["frac_acidic"],
    "sequence_complexity": ["shannon_entropy_aa"],
}
|
|
| |
# Normalised label text (lower-case, no spaces, '-' replaced by '_') mapped to
# the binary class: 1 = high FRET, 0 = low / no FRET.
LABEL_TO_CLASS = {
    "highfret": 1, "lowfret": 0,
    "highff": 1, "lowff": 0,
    "noff": 0, "nofret": 0, "no_fret": 0,
    "high_ff": 1, "low_ff": 0,
}

# ID columns probed in order of preference for each --csv-schema choice.
ID_COLUMNS_BY_SCHEMA = {
    "esm": ("variant", "sequence_id", "id"),
    "nn": ("sequence_id", "variant", "id"),
}

# Label columns probed in order of preference when none is given explicitly.
LABEL_COLUMN_CANDIDATES = ("label", "per_token_model_class", "target")
|
|
|
|
def parse_arguments():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace whose attributes are the upper-case ``dest`` names
        registered below (FASTA_PATH, CSV_PATH, OUTPUT_DIR, CSV_ID_COL,
        CSV_LABEL_COL, CSV_SCHEMA, SEED, C, ABLATION, SAVE_DIAGNOSTICS,
        SAVE_FEATURES, PERM_N_REPEATS).
    """
    parser = argparse.ArgumentParser(
        description="Logistic regression baseline on sequence composition features (no embeddings).",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = parser.add_argument

    # Mandatory input/output locations.
    for flag, dest, text in (
        ("--fasta", "FASTA_PATH", "FASTA with variant sequences"),
        ("--csv", "CSV_PATH", "CSV with variant IDs and FRET labels"),
        ("--output", "OUTPUT_DIR", "Directory for outputs"),
    ):
        add(flag, dest=dest, required=True, help=text)

    # CSV column selection.
    add(
        "--csv-id-col",
        dest="CSV_ID_COL",
        default=None,
        help="CSV column for FASTA header IDs (default: variant, or sequence_id with --csv-schema nn)",
    )
    add(
        "--csv-label-col",
        dest="CSV_LABEL_COL",
        default=None,
        help="CSV label column (default: auto — prefers 'label', then per_token_model_class)",
    )
    add(
        "--csv-schema",
        dest="CSV_SCHEMA",
        choices=["esm", "nn"],
        default="esm",
        help="esm: prefer variant ID column; nn: prefer sequence_id. Label column always prefers 'label' if present.",
    )

    # Model hyper-parameters.
    add(
        "--seed",
        dest="SEED",
        type=int,
        default=None,
        help="Random seed (default: generate and save to output_dir/random_seed.txt)",
    )
    add(
        "--c",
        dest="C",
        type=float,
        default=1.0,
        help="Inverse L2 regularization strength for LogisticRegression",
    )

    # Optional extras, all off unless the switch is present.
    for flag, dest, text in (
        ("--ablation", "ABLATION",
         "Also fit logistic models on each feature group alone and report test AUC"),
        ("--save-diagnostics", "SAVE_DIAGNOSTICS",
         "Save ROC plot/CSV, feature table, coefficient CSV, ablation CSV"),
        ("--save-features", "SAVE_FEATURES",
         "Write per-variant composition feature matrix CSV"),
    ):
        add(flag, dest=dest, action="store_true", help=text)

    add(
        "--perm-n-repeats",
        dest="PERM_N_REPEATS",
        type=int,
        default=30,
        help="Number of shuffle repeats for permutation importance (evaluated on test set)",
    )
    return parser.parse_args()
|
|
|
|
| def _pick_column(name: str, candidates: Tuple[str, ...], available: List[str]) -> Optional[str]: |
| if name in available: |
| return name |
| for col in candidates: |
| if col in available: |
| return col |
| return None |
|
|
|
|
def parse_binary_label(value) -> Optional[int]:
    """Map a CSV cell to 0 (lowFRET) or 1 (highFRET).

    Accepted inputs:
      * booleans and the numbers 0 / 1 (Python or NumPy scalars);
      * text that parses as the number 0 or 1 (e.g. "1", " 0 ", "1.0"), which
        is what pandas yields when a label column holds mixed content;
      * any spelling in LABEL_TO_CLASS after lower-casing, removing spaces and
        turning '-' into '_'; if that misses, the lookup is retried with the
        underscores removed so "high-FRET" / "low_fret" resolve as well.

    Returns:
        0 or 1, or None for missing values and anything unrecognised.
    """
    if value is None:
        return None
    try:
        if pd.isna(value):
            return None
    except (TypeError, ValueError):
        # pd.isna on an array-like gives an ambiguous truth value; such a cell
        # is not "missing", so fall through to the generic handling below.
        pass
    # bool must be tested before int because bool is an int subclass.
    if isinstance(value, (bool, np.bool_)):
        return int(value)
    if isinstance(value, (int, np.integer)):
        return int(value) if int(value) in (0, 1) else None
    if isinstance(value, (float, np.floating)):
        number = float(value)
        return int(number) if number in (0.0, 1.0) else None

    key = str(value).strip().lower().replace(" ", "").replace("-", "_")

    # Numeric text: only an exact 0 or 1 is a valid class.
    try:
        number = float(key)
    except ValueError:
        number = None
    if number is not None:
        return int(number) if number in (0.0, 1.0) else None

    label = LABEL_TO_CLASS.get(key)
    if label is None:
        # Separator-insensitive retry ("high_fret" -> "highfret").
        label = LABEL_TO_CLASS.get(key.replace("_", ""))
    return label
|
|
|
|
def resolve_csv_columns(args, csv_columns: List[str]) -> Tuple[str, str]:
    """Resolve the ID and label column names to read from the CSV.

    Each column is resolved independently: an explicit ``--csv-id-col`` /
    ``--csv-label-col`` is always honoured (even when only one of the two is
    supplied), and only the column left unspecified is auto-detected from the
    schema's candidate list.

    Args:
        args: parsed CLI namespace; reads CSV_ID_COL, CSV_LABEL_COL, CSV_SCHEMA.
        csv_columns: column names present in the CSV.

    Returns:
        (id_col, label_col)

    Raises:
        ValueError: if a requested column is absent, or auto-detection finds
            no candidate.
    """
    available = list(csv_columns)

    id_col = args.CSV_ID_COL
    if id_col:
        # An explicitly requested column must exist; never substitute another one.
        if id_col not in available:
            raise ValueError(
                f"ID column '{id_col}' (from --csv-id-col) not found in CSV; "
                f"columns: {available}."
            )
    else:
        id_candidates = ID_COLUMNS_BY_SCHEMA.get(args.CSV_SCHEMA, ID_COLUMNS_BY_SCHEMA["esm"])
        id_col = _pick_column("", id_candidates, available)
        if not id_col:
            raise ValueError(
                f"Could not find ID column in CSV. Tried {ID_COLUMNS_BY_SCHEMA.get(args.CSV_SCHEMA)}; "
                f"columns: {available}. Use --csv-id-col."
            )

    label_col = args.CSV_LABEL_COL
    if label_col:
        if label_col not in available:
            raise ValueError(
                f"Label column '{label_col}' (from --csv-label-col) not found in CSV; "
                f"columns: {available}."
            )
    else:
        label_col = _pick_column("", LABEL_COLUMN_CANDIDATES, available)
        if not label_col:
            raise ValueError(
                f"Could not find label column in CSV. Tried {LABEL_COLUMN_CANDIDATES}; "
                f"columns: {available}. Use --csv-label-col label"
            )
    return id_col, label_col
|
|
|
|
def standard_amino_acids(sequence: str) -> str:
    """Upper-case ``sequence`` and keep only residues listed in CANONICAL_AAS."""
    upper = str(sequence).upper()
    return "".join(filter(CANONICAL_AAS.__contains__, upper))
|
|
|
|
def compute_gravy(sequence: str) -> float:
    """Return the mean HYDROPATHY score over the canonical residues of ``sequence``.

    NaN is returned when no canonical residue remains after filtering.
    """
    clean = standard_amino_acids(sequence)
    if not clean:
        return float("nan")
    scores = [HYDROPATHY[aa] for aa in clean]
    return float(np.mean(scores))
|
|
|
|
def compute_isoelectric_point(sequence: str) -> float:
    """Return the Biopython isoelectric point of the canonical residues.

    NaN is returned for a sequence with no canonical residues, and also
    (best-effort) whenever the Biopython calculation raises.
    """
    clean = standard_amino_acids(sequence)
    if clean:
        try:
            return float(IsoelectricPoint(clean).pi())
        except Exception:
            # Deliberate best-effort: a failed pI becomes a missing feature.
            pass
    return float("nan")
|
|
|
|
def compute_net_charge_ph7(sequence: str) -> float:
    """Return the net charge at pH 7.0 of the canonical residues.

    Uses Biopython's ``ProteinAnalysis.charge_at_pH``; if that raises, falls
    back to a plain residue count (basic minus acidic). NaN is returned when
    no canonical residue remains after filtering.
    """
    clean = standard_amino_acids(sequence)
    if not clean:
        return float("nan")
    try:
        return float(ProteinAnalysis(clean).charge_at_pH(7.0))
    except Exception:
        basic = AA_GROUPS["frac_basic"]
        acidic = AA_GROUPS["frac_acidic"]
        positives = sum(aa in basic for aa in clean)
        negatives = sum(aa in acidic for aa in clean)
        return float(positives - negatives)
|
|
|
|
def compute_aa_group_fractions(sequence: str) -> Dict[str, float]:
    """Return, for every AA_GROUPS entry, the fraction of canonical residues in it.

    All values are NaN when the sequence has no canonical residues.
    """
    clean = standard_amino_acids(sequence)
    total = len(clean)
    if total == 0:
        return dict.fromkeys(AA_GROUPS, float("nan"))
    # Tally each residue once, then sum the tallies of a group's members.
    tally = Counter(clean)
    fractions = {}
    for group_name, members in AA_GROUPS.items():
        fractions[group_name] = sum(tally[aa] for aa in members) / total
    return fractions
|
|
|
|
def compute_complexity_features(sequence: str) -> Dict[str, float]:
    """Return the Shannon entropy (bits) of the canonical amino-acid composition.

    The result is a one-entry dict keyed ``shannon_entropy_aa``; the value is
    NaN when the sequence has no canonical residues.
    """
    clean = standard_amino_acids(sequence)
    total = len(clean)
    if total == 0:
        return {"shannon_entropy_aa": float("nan")}
    frequencies = (occurrences / total for occurrences in Counter(clean).values())
    entropy = 0.0
    for p in frequencies:
        entropy -= p * math.log2(p)
    return {"shannon_entropy_aa": entropy}
|
|
|
|
def compute_composition_features(sequence: str) -> Dict[str, float]:
    """Assemble every composition feature of ``sequence`` into one flat dict.

    Keys are: gravy_mean, isoelectric_point, net_charge_ph7, one entry per
    AA_GROUPS fraction, and shannon_entropy_aa.
    """
    gravy = compute_gravy(sequence)
    pi_value = compute_isoelectric_point(sequence)
    charge = compute_net_charge_ph7(sequence)
    fractions = compute_aa_group_fractions(sequence)
    complexity = compute_complexity_features(sequence)
    return {
        "gravy_mean": gravy,
        "isoelectric_point": pi_value,
        "net_charge_ph7": charge,
        **fractions,
        **complexity,
    }
|
|
|
|
def load_labeled_sequences(
    fasta_path: str, csv_path: str, id_col: str, label_col: str
) -> Tuple[List[str], List[str], np.ndarray, pd.DataFrame, np.ndarray]:
    """Join FASTA records with CSV labels and compute composition features.

    A FASTA record is kept only when its header equals (as a string) a value
    in ``id_col`` and that row's ``label_col`` parses to 0/1 via
    ``parse_binary_label``. When an ID occurs several times in the CSV the
    first row is used.

    Args:
        fasta_path: FASTA file of variant sequences.
        csv_path: CSV file holding IDs and labels.
        id_col: CSV column matched against FASTA headers.
        label_col: CSV column holding the FRET label.

    Returns:
        (variant_ids, sequences, y, feature_df, X) where ``y`` is the integer
        label vector, ``feature_df`` holds variant / label / sequence_length
        plus the feature columns, and ``X`` is the float matrix of the
        FEATURE_NAMES columns in that order.

    Raises:
        ValueError: if ``id_col`` or ``label_col`` is not a CSV column.
        RuntimeError: if no FASTA record ends up with a usable label.
    """
    df = pd.read_csv(csv_path, header=0, index_col=None)
    if id_col not in df.columns:
        raise ValueError(f"CSV missing ID column '{id_col}'. Columns: {list(df.columns)}")
    if label_col not in df.columns:
        raise ValueError(f"CSV missing label column '{label_col}'. Columns: {list(df.columns)}")

    print(f"Label column '{label_col}' value counts in CSV (all rows):")
    print(df[label_col].value_counts(dropna=False).head(20).to_string())

    # Build the ID -> raw label lookup once instead of re-scanning the whole
    # frame for every FASTA record. setdefault keeps the first CSV row for a
    # duplicated ID.
    raw_label_by_id: Dict[str, object] = {}
    for csv_id, raw in zip(df[id_col].astype(str), df[label_col]):
        raw_label_by_id.setdefault(csv_id, raw)

    variant_ids: List[str] = []
    sequences: List[str] = []
    labels: List[int] = []
    feature_rows: List[Dict[str, float]] = []
    skipped_unknown_label = Counter()

    for header, seq in esm.data.read_fasta(fasta_path):
        variant_id = header
        lookup_key = str(variant_id)
        if lookup_key not in raw_label_by_id:
            # FASTA entry without a CSV row: silently ignored.
            continue
        raw_label = raw_label_by_id[lookup_key]
        label = parse_binary_label(raw_label)
        if label is None:
            skipped_unknown_label[str(raw_label)] += 1
            continue
        variant_ids.append(variant_id)
        sequences.append(seq)
        labels.append(label)
        feature_rows.append(compute_composition_features(seq))

    if skipped_unknown_label:
        print(
            f"Warning: skipped {sum(skipped_unknown_label.values())} FASTA rows with "
            f"unrecognized labels in '{label_col}': {dict(skipped_unknown_label)}"
        )

    if not variant_ids:
        raise RuntimeError(
            f"No labeled variants matched between FASTA and CSV using id='{id_col}', "
            f"label='{label_col}'. Expected highFRET/lowFRET (or 0/1) in the label column."
        )

    feature_df = pd.DataFrame(feature_rows)
    feature_df.insert(0, "variant", variant_ids)
    feature_df.insert(1, "label", labels)
    # Length counts canonical residues only, matching what the features saw.
    feature_df.insert(2, "sequence_length", [len(standard_amino_acids(s)) for s in sequences])

    X = feature_df[FEATURE_NAMES].to_numpy(dtype=float)
    y = np.array(labels, dtype=int)
    return variant_ids, sequences, y, feature_df, X
|
|
|
|
def apply_temperature_scaling(y_prob: np.ndarray, temperature: float) -> np.ndarray:
    """Rescale probabilities by dividing their logits by ``temperature``.

    Inputs are clipped away from 0 and 1 before taking the logit so it stays
    finite; the result is clipped back into [0, 1].
    """
    tiny = 1e-15
    bounded = np.clip(y_prob, tiny, 1 - tiny)
    logits = np.log(bounded / (1 - bounded))
    rescaled = 1.0 / (1.0 + np.exp(-logits / temperature))
    return np.clip(rescaled, 0.0, 1.0)
|
|
|
|
def fit_temperature(y_true: np.ndarray, y_prob: np.ndarray, n_trials: int = 100) -> float:
    """Grid-search the temperature that minimises ECE on the given split.

    ``n_trials`` log-spaced candidates between 1e-2 and 1e1 are tried; the
    first candidate reaching the lowest ECE wins. 1.0 is returned if no
    candidate yields an ECE below infinity.
    """
    chosen = 1.0
    lowest = float("inf")
    for candidate in np.logspace(-2, 1, n_trials):
        error = _compute_ece(y_true, apply_temperature_scaling(y_prob, candidate))
        if error < lowest:
            chosen, lowest = candidate, error
    return chosen
|
|
|
|
| def _compute_ece(y_true: np.ndarray, y_prob: np.ndarray, n_bins: int = 10) -> float: |
| bin_boundaries = np.linspace(0, 1, n_bins + 1) |
| ece = 0.0 |
| for i in range(n_bins): |
| in_bin = (y_prob > bin_boundaries[i]) & (y_prob <= bin_boundaries[i + 1]) |
| prop = np.mean(in_bin) |
| if prop > 0: |
| ece += prop * abs(np.mean(y_true[in_bin]) - np.mean(y_prob[in_bin])) |
| return float(ece) |
|
|
|
|
def optimal_f1_threshold(y_true: np.ndarray, y_prob: np.ndarray, n_thresholds: int = 101) -> float:
    """Return the cut-off in [0, 1] that maximises F1 for ``y_prob >= cut-off``.

    ``n_thresholds`` evenly spaced cut-offs are scanned in ascending order and
    the lowest one achieving the best F1 is returned (0.5 if none is scanned).
    Precision, recall and F1 are taken as 0 when their denominator is 0.
    """
    is_pos = y_true == 1
    is_neg = y_true == 0
    winner, top_score = 0.5, -1.0
    for cut in np.linspace(0, 1, n_thresholds):
        flagged = y_prob >= cut
        tp = np.sum(flagged & is_pos)
        fp = np.sum(flagged & is_neg)
        fn = np.sum(~flagged & is_pos)
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        if (precision + recall) > 0:
            score = 2 * precision * recall / (precision + recall)
        else:
            score = 0.0
        if score > top_score:
            winner, top_score = cut, score
    return float(winner)
|
|
|
|
def fit_lr_pipeline(
    X_train: np.ndarray,
    y_train: np.ndarray,
    C: float,
    seed: int,
) -> Tuple[LogisticRegression, StandardScaler]:
    """Standardise the training features and fit an L2 logistic regression.

    Returns:
        (clf, scaler): the fitted classifier and the scaler fitted on
        ``X_train``; the same scaler must be applied before any prediction.
    """
    scaler = StandardScaler().fit(X_train)
    model_options = {
        "C": C,
        "penalty": "l2",
        "solver": "lbfgs",
        "max_iter": 2000,
        "random_state": seed,
    }
    clf = LogisticRegression(**model_options)
    clf.fit(scaler.transform(X_train), y_train)
    return clf, scaler
|
|
|
|
def predict_proba(
    clf: LogisticRegression, scaler: StandardScaler, X: np.ndarray
) -> np.ndarray:
    """Return the second ``predict_proba`` column (class 1) for raw features ``X``."""
    standardized = scaler.transform(X)
    class_probabilities = clf.predict_proba(standardized)
    return class_probabilities[:, 1]
|
|
|
|
def compute_permutation_importance_table(
    clf: LogisticRegression,
    scaler: StandardScaler,
    X: np.ndarray,
    y: np.ndarray,
    feature_names: List[str],
    *,
    seed: int,
    n_repeats: int = 30,
    scoring: str = "roc_auc",
    eval_split: str = "test",
) -> pd.DataFrame:
    """Rank features by permutation importance on scaler-transformed data.

    Each column of the scaled matrix is shuffled ``n_repeats`` times and the
    drop in ``scoring`` is recorded; a larger importance_mean means shuffling
    that feature hurt the score more.

    Returns:
        DataFrame sorted by importance_mean (descending) with columns rank,
        feature, importance_mean, importance_std, eval_split, scoring,
        n_repeats and feature_group (looked up in FEATURE_GROUPS).
    """
    outcome = permutation_importance(
        clf,
        scaler.transform(X),
        y,
        scoring=scoring,
        n_repeats=n_repeats,
        random_state=seed,
        n_jobs=-1,
    )
    columns = {
        "feature": feature_names,
        "importance_mean": outcome.importances_mean,
        "importance_std": outcome.importances_std,
        "eval_split": eval_split,
        "scoring": scoring,
        "n_repeats": n_repeats,
    }
    table = pd.DataFrame(columns)
    table = table.sort_values("importance_mean", ascending=False).reset_index(drop=True)
    table.insert(0, "rank", np.arange(1, len(table) + 1))

    # Reverse FEATURE_GROUPS so each feature can be tagged with its group.
    group_of = {}
    for group, members in FEATURE_GROUPS.items():
        for member in members:
            group_of[member] = group
    table["feature_group"] = table["feature"].map(group_of)
    return table
|
|
|
|
def matthews_corrcoef_safe(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Matthews correlation coefficient computed from the 2x2 confusion matrix.

    Returns NaN instead of warning when the coefficient is undefined (a zero
    denominator, a non-2x2 matrix, or a non-finite result).
    """
    undefined = float("nan")
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    if cm.shape != (2, 2):
        return undefined
    # Row-major order of a [0, 1]-labelled matrix is tn, fp, fn, tp.
    tn, fp, fn, tp = (int(cell) for cell in cm.ravel())
    denom = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    if denom <= 0:
        return undefined
    mcc = (tp * tn - fp * fn) / np.sqrt(float(denom))
    if not np.isfinite(mcc):
        return undefined
    return float(mcc)
|
|
|
|
def evaluate_split(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    threshold: float,
) -> Dict[str, float]:
    """Score one data split at a fixed decision threshold.

    Returns:
        Dict with roc_auc (NaN if roc_auc_score raises ValueError), accuracy,
        average_precision, mcc, sensitivity, specificity, precision and f1.
        Ratio metrics are 0.0 when their denominator is zero.
    """
    y_pred = (y_prob >= threshold).astype(int)

    try:
        auc = float(roc_auc_score(y_true, y_prob))
    except ValueError:
        auc = float("nan")

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel() if cm.size == 4 else (0, 0, 0, 0)

    def _ratio(numerator, denominator):
        # Guarded division shared by the four count-based metrics.
        return float(numerator / denominator) if denominator > 0 else 0.0

    accuracy = float(accuracy_score(y_true, y_pred))
    avg_precision = float(average_precision_score(y_true, y_prob))
    mcc = matthews_corrcoef_safe(y_true, y_pred)
    return {
        "roc_auc": auc,
        "accuracy": accuracy,
        "average_precision": avg_precision,
        "mcc": mcc,
        "sensitivity": _ratio(tp, tp + fn),
        "specificity": _ratio(tn, tn + fp),
        "precision": _ratio(tp, tp + fp),
        "f1": _ratio(2 * tp, 2 * tp + fp + fn),
    }
|
|
|
|
def train_eval_composition_model(
    X: np.ndarray,
    y: np.ndarray,
    feature_names: List[str],
    C: float,
    seed: int,
) -> Dict:
    """Fit, calibrate and evaluate one logistic-regression model.

    Protocol: a stratified 80/20 train/test split, then a stratified 80/20
    fit/validation split of the training part. The classifier is fitted on
    the fit part; temperature and F1-optimal threshold are chosen on the
    validation part and then applied unchanged to the test part.

    Returns:
        Dict with the fitted ``clf`` and ``scaler``, ``feature_names``,
        ``temperature``, ``threshold``, ``test_metrics``, ``val_metrics``,
        per-feature ``coefficients``, and the arrays ``y_test``, ``y_val``,
        ``y_test_proba`` (calibrated) and ``X_test``.
    """
    # Both splits share the seed so a run is reproducible from one number.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=0.8, random_state=seed, stratify=y
    )
    X_fit, X_val, y_fit, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=seed, stratify=y_train
    )

    clf, scaler = fit_lr_pipeline(X_fit, y_fit, C=C, seed=seed)

    # Calibration parameters come from validation data only.
    val_raw = predict_proba(clf, scaler, X_val)
    temperature = fit_temperature(y_val, val_raw)
    val_calibrated = apply_temperature_scaling(val_raw, temperature)
    threshold = optimal_f1_threshold(y_val, val_calibrated)

    test_raw = predict_proba(clf, scaler, X_test)
    test_calibrated = apply_temperature_scaling(test_raw, temperature)

    test_metrics = evaluate_split(y_test, test_calibrated, threshold)
    val_metrics = evaluate_split(y_val, val_calibrated, threshold)

    coefficients = dict(zip(feature_names, map(float, clf.coef_[0])))

    return {
        "clf": clf,
        "scaler": scaler,
        "feature_names": feature_names,
        "temperature": temperature,
        "threshold": threshold,
        "test_metrics": test_metrics,
        "val_metrics": val_metrics,
        "coefficients": coefficients,
        "y_test": y_test,
        "y_val": y_val,
        "y_test_proba": test_calibrated,
        "X_test": X_test,
    }
|
|
|
|
def save_roc_plot(y_true: np.ndarray, y_prob: np.ndarray, auc: float, path: str, title: str):
    """Draw the ROC curve with a chance diagonal and save it to ``path`` at 150 dpi."""
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_prob)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot(false_pos_rate, true_pos_rate, label=f"ROC (AUC = {auc:.3f})")
    ax.plot([0, 1], [0, 1], "k--")  # chance level
    ax.set_xlabel("False positive rate")
    ax.set_ylabel("True positive rate")
    ax.set_title(title)
    ax.legend()
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    fig.tight_layout()
    fig.savefig(path, dpi=150)
    # Close explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)
|
|
|
|
def main():
    """Run the composition-only baseline end to end.

    Steps: parse the CLI, fix the random seed, load and featurise labelled
    sequences, drop rows with NaN features, check that both classes are still
    present, train/calibrate/evaluate the full model, compute permutation
    importance on the test split, optionally run the per-group ablation, and
    write all artifacts into the output directory.

    Raises:
        ValueError: if fewer than two classes remain after NaN filtering.
    """
    args = parse_arguments()
    os.makedirs(args.OUTPUT_DIR, exist_ok=True)

    if args.SEED is not None:
        seed = args.SEED
        print(f"Using provided random seed: {seed}")
    else:
        seed = secrets.randbelow(2**32)
        with open(os.path.join(args.OUTPUT_DIR, "random_seed.txt"), "w") as f:
            f.write(f"{seed}\n")
        print(f"Generated random seed: {seed} (saved to {args.OUTPUT_DIR}/random_seed.txt)")
    np.random.seed(seed)

    # Only the header is needed to resolve column names.
    csv_preview = pd.read_csv(args.CSV_PATH, nrows=5)
    id_col, label_col = resolve_csv_columns(args, list(csv_preview.columns))
    print("=" * 80)
    print("Composition-only logistic regression baseline (no embeddings)")
    print("=" * 80)
    print(f"FASTA: {args.FASTA_PATH}")
    print(f"CSV: {args.CSV_PATH} (id={id_col}, label={label_col})")
    print(f"Output: {args.OUTPUT_DIR}")
    print(f"Features ({len(FEATURE_NAMES)}): {', '.join(FEATURE_NAMES)}")

    variant_ids, _sequences, y, feature_df, X = load_labeled_sequences(
        args.FASTA_PATH, args.CSV_PATH, id_col, label_col
    )
    print(f"\nLoaded {len(variant_ids)} variants")

    # Drop incomplete rows BEFORE validating the class balance, so the check
    # reflects the data that actually reaches the stratified split.
    valid = ~np.isnan(X).any(axis=1)
    n_nan = int((~valid).sum())
    if n_nan:
        print(f"Warning: dropping {n_nan} variants with NaN composition features")
        X = X[valid]
        y = y[valid]
        feature_df = feature_df.loc[valid].reset_index(drop=True)
        variant_ids = [v for v, keep in zip(variant_ids, valid) if keep]

    # tolist() gives plain ints so the printed Counter stays readable.
    class_counts = Counter(y.tolist())
    print(f"Class distribution (0=lowFRET, 1=highFRET): {class_counts}")
    if len(class_counts) < 2:
        raise ValueError(
            f"Need both classes for logistic regression, but only found: {class_counts}. "
            f"Check that --csv-label-col points to the column with highFRET and lowFRET "
            f"(currently using '{label_col}'). For FF_all_rounds CSV use "
            f"--csv-label-col label --csv-id-col variant (default with --csv-schema esm)."
        )

    if args.SAVE_FEATURES or args.SAVE_DIAGNOSTICS:
        feature_df.to_csv(os.path.join(args.OUTPUT_DIR, "composition_features_per_variant.csv"), index=False)

    print("\nTraining full composition logistic regression...")
    result = train_eval_composition_model(X, y, FEATURE_NAMES, C=args.C, seed=seed)
    test_auc = result["test_metrics"]["roc_auc"]
    val_auc = result["val_metrics"]["roc_auc"]

    print("\n" + "=" * 80)
    print("PRIMARY RESULT — held-out test set")
    print("=" * 80)
    print(f"Test ROC-AUC (composition baseline): {test_auc:.4f}")
    print(f"Validation ROC-AUC: {val_auc:.4f}")
    if test_auc >= 0.85:
        print(
            "\n*** Test AUC >= 0.85: strong composition-only signal. "
            "ESM-2 gains may largely reflect amino-acid composition / biophysical "
            "properties rather than learned sequence context beyond these features. ***"
        )
    else:
        print(
            "\nTest AUC < 0.85: composition features alone do not match typical "
            "strong ESM performance; embeddings may carry additional predictive signal."
        )

    print("\nTest metrics:")
    for k, v in result["test_metrics"].items():
        print(f" {k}: {v:.4f}")
    print("\nStandardized logistic coefficients (positive class = highFRET):")
    for name, coef in sorted(result["coefficients"].items(), key=lambda x: -abs(x[1])):
        print(f" {name}: {coef:+.4f}")

    print(
        f"\nPermutation importance (test set, scoring=roc_auc, "
        f"n_repeats={args.PERM_N_REPEATS})..."
    )
    perm_df = compute_permutation_importance_table(
        result["clf"],
        result["scaler"],
        result["X_test"],
        result["y_test"],
        FEATURE_NAMES,
        seed=seed,
        n_repeats=args.PERM_N_REPEATS,
        eval_split="test",
    )
    perm_path = os.path.join(args.OUTPUT_DIR, "composition_lr_permutation_importance.csv")
    perm_df.to_csv(perm_path, index=False)
    print("Top features by mean ROC-AUC decrease when permuted:")
    for _, row in perm_df.head(len(FEATURE_NAMES)).iterrows():
        print(
            f" {int(row['rank'])}. {row['feature']}: "
            f"{row['importance_mean']:.4f} ± {row['importance_std']:.4f}"
        )

    print("\nClassification report (test):")
    y_pred = (result["y_test_proba"] >= result["threshold"]).astype(int)
    print(classification_report(result["y_test"], y_pred, zero_division=0))

    summary = pd.DataFrame(
        [
            {
                "model": "composition_lr_full",
                "n_features": len(FEATURE_NAMES),
                "features": ";".join(FEATURE_NAMES),
                "test_roc_auc": test_auc,
                "val_roc_auc": val_auc,
                "test_accuracy": result["test_metrics"]["accuracy"],
                "test_f1": result["test_metrics"]["f1"],
                "test_mcc": result["test_metrics"]["mcc"],
                "temperature": result["temperature"],
                "optimal_threshold_f1": result["threshold"],
                "C": args.C,
                "n_variants": len(y),
                "composition_auc_ge_0_85": test_auc >= 0.85,
            }
        ]
    )
    summary.to_csv(os.path.join(args.OUTPUT_DIR, "composition_baseline_summary.csv"), index=False)
    pd.DataFrame({"auc_score": [test_auc]}).to_csv(
        os.path.join(args.OUTPUT_DIR, "composition_lr_auc_score.csv"), index=False
    )
    pd.DataFrame([result["coefficients"]]).to_csv(
        os.path.join(args.OUTPUT_DIR, "composition_lr_coefficients.csv"), index=False
    )
    pd.DataFrame(
        {
            "temperature": [result["temperature"]],
            "optimal_threshold_f1": [result["threshold"]],
        }
    ).to_csv(
        os.path.join(args.OUTPUT_DIR, "composition_lr_validation_diagnostics_summary.csv"),
        index=False,
    )
    pd.DataFrame({"temperature": [result["temperature"]]}).to_csv(
        os.path.join(args.OUTPUT_DIR, "composition_lr_validation_temperature.csv"),
        index=False,
    )

    model_path = os.path.join(args.OUTPUT_DIR, "composition_lr_model.joblib")
    joblib.dump(
        {
            "clf": result["clf"],
            "scaler": result["scaler"],
            "feature_names": FEATURE_NAMES,
            "C": args.C,
            # Calibration parameters travel with the model so the reported
            # calibrated, thresholded predictions can be reproduced.
            "temperature": float(result["temperature"]),
            "threshold": float(result["threshold"]),
        },
        model_path,
    )
    print(f"Inference model saved: {model_path}")

    if args.SAVE_DIAGNOSTICS:
        fpr, tpr, thresholds = roc_curve(result["y_test"], result["y_test_proba"])
        pd.DataFrame({"fpr": fpr, "tpr": tpr, "thresholds": thresholds}).to_csv(
            os.path.join(args.OUTPUT_DIR, "composition_lr_roc_curve.csv"), index=False
        )
        save_roc_plot(
            result["y_test"],
            result["y_test_proba"],
            test_auc,
            os.path.join(args.OUTPUT_DIR, "composition_lr_roc_curve.png"),
            "Composition LR (test)",
        )

    ablation_rows = []
    if args.ABLATION:
        print("\n" + "-" * 80)
        print("Ablation: test ROC-AUC by feature group")
        print("-" * 80)
        for group_name, cols in FEATURE_GROUPS.items():
            # Same seed, so every group is scored on the same splits.
            col_idx = [FEATURE_NAMES.index(c) for c in cols]
            X_sub = X[:, col_idx]
            sub_result = train_eval_composition_model(X_sub, y, cols, C=args.C, seed=seed)
            auc_sub = sub_result["test_metrics"]["roc_auc"]
            print(f" {group_name}: AUC = {auc_sub:.4f} ({', '.join(cols)})")
            ablation_rows.append(
                {
                    "feature_group": group_name,
                    "features": ";".join(cols),
                    "test_roc_auc": auc_sub,
                    "val_roc_auc": sub_result["val_metrics"]["roc_auc"],
                }
            )
        pd.DataFrame(ablation_rows).to_csv(
            os.path.join(args.OUTPUT_DIR, "composition_baseline_ablation.csv"), index=False
        )

    meta = {
        "model": "sklearn LogisticRegression (L2)",
        "features": FEATURE_NAMES,
        "feature_groups": FEATURE_GROUPS,
        "split": "80% train / 20% test; 20% of train for validation calibration",
        "test_roc_auc": test_auc,
        "interpretation_threshold": 0.85,
        "csv_id_col": id_col,
        "csv_label_col": label_col,
        # Recorded here too, because random_seed.txt is only written when the
        # seed was generated rather than supplied via --seed.
        "random_seed": int(seed),
    }
    with open(os.path.join(args.OUTPUT_DIR, "composition_baseline_metadata.json"), "w") as f:
        json.dump(meta, f, indent=2)

    print(f"\nArtifacts written to '{args.OUTPUT_DIR}/':")
    print(" - composition_baseline_summary.csv (primary AUC + interpretation flag)")
    print(" - composition_lr_model.joblib (clf + scaler for evaluate_lr_sequence_composition.py)")
    print(" - composition_lr_auc_score.csv")
    print(" - composition_lr_coefficients.csv")
    print(" - composition_lr_permutation_importance.csv")
    print(" - composition_lr_validation_diagnostics_summary.csv")
    print(" - composition_lr_validation_temperature.csv")
    if args.ABLATION:
        print(" - composition_baseline_ablation.csv")
    if args.SAVE_FEATURES or args.SAVE_DIAGNOSTICS:
        print(" - composition_features_per_variant.csv")
    if args.SAVE_DIAGNOSTICS:
        print(" - composition_lr_roc_curve.csv / .png")
|
|
|
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|