| """Evaluate FASTA with a trained composition-features logistic regression (optional labeled CSV).""" |
| from __future__ import annotations |
|
|
| import argparse |
| import math |
| import os |
| import sys |
| from pathlib import Path |
|
|
| import joblib |
| import numpy as np |
| import pandas as pd |
|
|
| |
| |
| |
| |
| |
| |
# Repository root: two levels up from this file (the parent of the script's folder).
_REPO_ROOT = Path(__file__).resolve().parents[1]
# Default checkpoint directory; resolve_model_dir() falls back to it when
# no --model-dir is given on the command line.
MODEL_CHECKPOINT_DIR = _REPO_ROOT.joinpath('weights', 'composition_lr')

# Make <repo>/models importable: prepend it to sys.path unless it is already listed.
_MODELS_DIR = _REPO_ROOT.joinpath('models')
if sys.path.count(str(_MODELS_DIR)) == 0:
    sys.path[:0] = [str(_MODELS_DIR)]
|
|
| from lr_sequence_composition_baseline import ( |
| FEATURE_NAMES, |
| compute_composition_features, |
| predict_proba, |
| standard_amino_acids, |
| ) |
|
|
# Directory containing this script; prepended to sys.path (once) so that
# sibling modules can be imported when the script is run directly.
_SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
if sys.path.count(_SCRIPT_DIR) == 0:
    sys.path[:0] = [_SCRIPT_DIR]
|
|
| import eval_metrics as em |
|
|
|
|
def resolve_model_dir(cli_value: str | None) -> str:
    """Pick the checkpoint directory to load the model from.

    Args:
        cli_value: Value of ``--model-dir``; ``None`` or an empty string
            means "not provided".

    Returns:
        The absolute form of ``cli_value`` when it is provided, otherwise
        the resolved ``MODEL_CHECKPOINT_DIR`` as a string.
    """
    if not cli_value:
        # Nothing on the command line: use the module-level default.
        return str(MODEL_CHECKPOINT_DIR.resolve())
    return os.path.abspath(cli_value)
|
|
|
|
def _first_float(frame, column):
    """Return ``float(frame[column].iloc[0])``, or ``None`` if the column or first row is absent."""
    if column in frame.columns and len(frame) > 0:
        return float(frame[column].iloc[0])
    return None


def _usable_or_default(value, default, label, *, require_positive):
    """Return *value* if it is finite (and > 0 when required); otherwise warn and return *default*."""
    if math.isfinite(value) and (value > 0 or not require_positive):
        return value
    print(f'Warning: ignoring invalid {label}={value!r}; using {default}')
    return default


def load_inference_operating_point(model_dir: str, model_name: str = 'composition_lr') -> tuple[float, float]:
    """Load the calibration temperature and decision threshold saved next to a model.

    Files consulted inside ``model_dir``:

    * ``{model_name}_validation_temperature.csv`` -- column ``temperature``.
    * ``{model_name}_validation_validation_diagnostics_summary.csv``, falling
      back to ``{model_name}_validation_diagnostics_summary.csv`` -- column
      ``optimal_threshold_f1``. Its ``temperature`` column is used only when
      the dedicated temperature file does not exist.

    Only the first data row of each file is read. Loading is best-effort: a
    missing file, missing column, unreadable file or invalid value leaves the
    corresponding default in place (a warning is printed for read errors and
    invalid values). A temperature must be finite and strictly positive and a
    threshold must be finite; NaN (e.g. an empty CSV cell), infinities and
    non-positive temperatures are rejected.

    Args:
        model_dir: Directory containing the validation CSV files.
        model_name: Filename prefix of the validation CSV files.

    Returns:
        ``(temperature, threshold)``; defaults are ``(1.0, 0.5)``.
    """
    temperature = 1.0
    threshold = 0.5

    temp_path = os.path.join(model_dir, f'{model_name}_validation_temperature.csv')
    has_temp_file = os.path.isfile(temp_path)
    if has_temp_file:
        # Broad catch is deliberate: any read/parse problem must not abort inference.
        try:
            candidate = _first_float(pd.read_csv(temp_path), 'temperature')
            if candidate is not None:
                temperature = _usable_or_default(
                    candidate, temperature, 'temperature', require_positive=True
                )
        except Exception as e:
            print(f'Warning: temperature file: {e}')

    # The doubled "validation_validation" name is tried first, then the single form.
    summary_path = os.path.join(model_dir, f'{model_name}_validation_validation_diagnostics_summary.csv')
    if not os.path.isfile(summary_path):
        summary_path = os.path.join(model_dir, f'{model_name}_validation_diagnostics_summary.csv')
    if os.path.isfile(summary_path):
        try:
            sdf = pd.read_csv(summary_path)
            candidate = _first_float(sdf, 'optimal_threshold_f1')
            if candidate is not None:
                threshold = _usable_or_default(
                    candidate, threshold, 'optimal_threshold_f1', require_positive=False
                )
            # The dedicated temperature file, when present, always takes precedence.
            if not has_temp_file:
                candidate = _first_float(sdf, 'temperature')
                if candidate is not None:
                    temperature = _usable_or_default(
                        candidate, temperature, 'temperature', require_positive=True
                    )
        except Exception as e:
            print(f'Warning: diagnostics summary: {e}')
    return temperature, threshold
|
|
|
|
def load_composition_model(model_dir: str):
    """Load the exported composition-LR bundle from ``model_dir``.

    The bundle is read from ``composition_lr_model.joblib`` and must provide
    the keys ``'clf'`` and ``'scaler'``; ``'feature_names'`` is optional and
    defaults to the imported ``FEATURE_NAMES``.

    Args:
        model_dir: Directory expected to contain the joblib file.

    Returns:
        ``(clf, scaler, feature_names)`` with ``feature_names`` as a list.

    Raises:
        FileNotFoundError: If the joblib file is not present in ``model_dir``.
    """
    checkpoint_file = os.path.join(model_dir, 'composition_lr_model.joblib')
    if not os.path.isfile(checkpoint_file):
        message = (
            f'{checkpoint_file} not found. Re-run lr_sequence_composition_baseline.py to export '
            'composition_lr_model.joblib, or pass --model-dir to a directory that contains it.'
        )
        raise FileNotFoundError(message)

    # NOTE(review): joblib.load unpickles the file -- only load checkpoints from trusted sources.
    payload = joblib.load(checkpoint_file)
    classifier = payload['clf']
    feature_scaler = payload['scaler']
    stored_names = list(payload.get('feature_names', FEATURE_NAMES))

    # The checkpoint's own feature order wins; only tell the user when it differs.
    if stored_names != FEATURE_NAMES:
        mismatch_note = (
            f'Warning: checkpoint feature_names differ from current script; '
            f'using checkpoint order ({len(stored_names)} features).'
        )
        print(mismatch_note)
    return classifier, feature_scaler, stored_names
|
|
|
|
def parse_arguments():
    """Define this script's command-line options and parse ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``fasta``, ``model_dir``,
        ``output``, ``csv``, ``no_detailed_metrics`` and ``include_features``.
        ``--fasta`` and ``--output`` are required; the two path options default
        to ``None`` and the two flags default to ``False``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Evaluate FASTA with composition LR (gravy, pI, charge, complexity).',
    )

    # (flag, add_argument keyword arguments), registered in this order so the
    # generated --help output lists the options in the same sequence.
    option_table = (
        ('--fasta', {'required': True, 'help': 'Input FASTA file'}),
        (
            '--model-dir',
            {
                'default': None,
                'help': 'Overrides MODEL_CHECKPOINT_DIR defined at the top of this script',
            },
        ),
        ('--output', {'required': True, 'help': 'Output directory'}),
        (
            '--csv',
            {
                'default': None,
                'help': 'Optional CSV with id column (variant/sequence_id/id) and label (highFRET/lowFRET)',
            },
        ),
        (
            '--no-detailed-metrics',
            {
                'action': 'store_true',
                'help': 'With --csv: one-line summary only; skip diagnostic CSV/report output',
            },
        ),
        (
            '--include-features',
            {
                'action': 'store_true',
                'help': 'Add per-sequence composition feature columns to the predictions CSV',
            },
        ),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()
|
|
|
|
def read_fasta_robust(fasta_path):
    """Lazily yield ``(header, sequence)`` pairs from a FASTA file.

    Blank lines are ignored, every line is stripped of surrounding
    whitespace, and the lines of a record are concatenated into one string.
    Text appearing before the first ``>`` line is discarded, and a header
    that has no sequence lines is not yielded. The header is the text after
    ``>`` with surrounding whitespace removed.

    Args:
        fasta_path: Path of the FASTA file to read.

    Yields:
        Tuples ``(header, sequence)`` in file order.
    """
    current_header = None
    chunks = []
    with open(fasta_path, 'r') as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text == '':
                continue
            if not text.startswith('>'):
                # Sequence data only counts once a header has been seen.
                if current_header is not None:
                    chunks.append(text)
                continue
            # New header: emit the previous record if it collected any sequence.
            if current_header is not None and chunks:
                yield (current_header, ''.join(chunks))
            current_header = text[1:].strip()
            chunks = []
    # Emit the trailing record, which has no following header to trigger it.
    if current_header is not None and chunks:
        yield (current_header, ''.join(chunks))
|
|
|
|
def _pair_labels_with_probabilities(sequence_ids, probabilities, label_map):
    """Return parallel ``(labels, probabilities)`` lists for the IDs found in ``label_map``.

    IDs and probabilities are matched by position, so records that share an
    ID each keep their own probability.
    """
    labels: list[int] = []
    matched_probs: list[float] = []
    for sid, prob in zip(sequence_ids, probabilities):
        lab = label_map.get(str(sid).strip())
        if lab is None:
            continue
        labels.append(lab)
        matched_probs.append(float(prob))
    return labels, matched_probs


def main():
    """Run composition-LR inference on a FASTA file and optionally evaluate against labels.

    Steps:
      1. Parse arguments, resolve the checkpoint directory, create the output
         directory, and load the model bundle and operating point.
      2. Compute composition features per FASTA record; records with any NaN
         feature are skipped (reported on stdout and, when predictions are
         written, in ``skipped_sequences.csv``).
      3. Predict probabilities, apply temperature scaling, and write
         ``composition_lr_predictions.csv`` (``highFRET`` when the probability
         is at or above the threshold, otherwise ``lowFRET``).
      4. If ``--csv`` names an existing file, match labels to FASTA IDs and
         call ``em.run_supervised_evaluation``.

    Returns early, after printing a message, when there are no usable
    sequences, no usable label CSV, or no overlap between FASTA IDs and labels.
    """
    args = parse_arguments()
    model_dir = resolve_model_dir(args.model_dir)
    # bool() so the evaluation helper receives a real flag, not the CSV path or None.
    detailed = bool(args.csv) and not args.no_detailed_metrics
    os.makedirs(args.output, exist_ok=True)

    print(f'Checkpoint directory: {model_dir}')
    clf, scaler, feature_names = load_composition_model(model_dir)
    temperature, threshold = load_inference_operating_point(model_dir)
    print(f'Operating point: temperature={temperature:.4f}, threshold={threshold:.4f}')

    sequence_ids: list[str] = []
    sequences: list[str] = []
    feature_rows: list[dict[str, float]] = []
    skipped: list[tuple[str, int, str]] = []

    for header, seq in read_fasta_robust(args.fasta):
        feats = compute_composition_features(seq)
        # Keep only the checkpoint's features, in the checkpoint's order.
        row = {name: feats[name] for name in feature_names}
        if any(math.isnan(v) for v in row.values()):
            skipped.append((header, len(standard_amino_acids(seq)), 'nan_composition_features'))
            continue
        sequence_ids.append(header)
        sequences.append(seq)
        feature_rows.append(row)

    for header, length, reason in skipped:
        print(f'Warning: skip {header} ({reason}, canonical_aa_len={length})')

    if not sequences:
        print('ERROR: No sequences to process.')
        return

    X = np.array([[row[n] for n in feature_names] for row in feature_rows], dtype=float)
    proba = predict_proba(clf, scaler, X)
    proba = em.apply_temperature_scaling(proba, temperature)

    results = pd.DataFrame(
        {
            'sequence_id': sequence_ids,
            'sequence_length': [len(standard_amino_acids(s)) for s in sequences],
            'prediction_probability': proba,
            'predicted_class': np.where(proba >= threshold, 'highFRET', 'lowFRET'),
        }
    )
    if args.include_features:
        feat_df = pd.DataFrame(feature_rows)
        results = pd.concat([results, feat_df], axis=1)

    out_predictions = os.path.join(args.output, 'composition_lr_predictions.csv')
    results.to_csv(out_predictions, index=False)
    print(f'Predictions saved: {out_predictions}')

    if skipped:
        pd.DataFrame(skipped, columns=['sequence_id', 'length', 'reason']).to_csv(
            os.path.join(args.output, 'skipped_sequences.csv'), index=False
        )

    if not args.csv:
        print('No labeled CSV; predictions only.')
        return
    if not os.path.isfile(args.csv):
        # Distinguish "not requested" from "requested but missing" so a typo is visible.
        print(f'Warning: labeled CSV not found: {args.csv}; predictions only.')
        return

    try:
        _, _, label_map = em.build_label_map_int(pd.read_csv(args.csv))
    except ValueError as e:
        print(f'Skipping metrics: {e}')
        return

    # Pair by row position: one linear pass, and duplicate FASTA IDs keep their
    # own probability instead of all receiving the first match's value.
    y_true_list, y_prob_list = _pair_labels_with_probabilities(
        sequence_ids, results['prediction_probability'], label_map
    )
    if not y_true_list:
        print('No overlapping labels between FASTA IDs and CSV. Skipping metrics.')
        return

    y_true = np.asarray(y_true_list)
    y_prob = np.asarray(y_prob_list)
    y_pred = (y_prob >= threshold).astype(int)
    em.run_supervised_evaluation(
        y_true, y_prob, y_pred, 'composition_lr', args.output, threshold, detailed_metrics=detailed
    )
    if detailed:
        print(f'Detailed diagnostics written under {args.output}')
|
|
|
|
| if __name__ == '__main__': |
| main() |
|
|