"""Evaluate FASTA with a trained composition-features logistic regression (optional labeled CSV).""" from __future__ import annotations import argparse import math import os import sys from pathlib import Path import joblib import numpy as np import pandas as pd # # ===================================================================== # Default checkpoint: training output dir or repository `weights/composition_lr/` # (see evaluation/README.md). Must contain composition_lr_model.joblib (+ calibration CSVs). # Override via --model-dir. # ===================================================================== _REPO_ROOT = Path(__file__).resolve().parents[1] MODEL_CHECKPOINT_DIR = _REPO_ROOT / 'weights' / 'composition_lr' # _MODELS_DIR = _REPO_ROOT / 'models' if str(_MODELS_DIR) not in sys.path: sys.path.insert(0, str(_MODELS_DIR)) from lr_sequence_composition_baseline import ( # noqa: E402 FEATURE_NAMES, compute_composition_features, predict_proba, standard_amino_acids, ) _SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) if _SCRIPT_DIR not in sys.path: sys.path.insert(0, _SCRIPT_DIR) import eval_metrics as em # noqa: E402 def resolve_model_dir(cli_value: str | None) -> str: if cli_value: return os.path.abspath(cli_value) return str(MODEL_CHECKPOINT_DIR.resolve()) def load_inference_operating_point(model_dir: str, model_name: str = 'composition_lr') -> tuple[float, float]: temperature = 1.0 threshold = 0.5 temp_path = os.path.join(model_dir, f'{model_name}_validation_temperature.csv') if os.path.isfile(temp_path): try: df = pd.read_csv(temp_path) if 'temperature' in df.columns and len(df) > 0: temperature = float(df['temperature'].iloc[0]) except Exception as e: print(f'Warning: temperature file: {e}') summary_path = os.path.join(model_dir, f'{model_name}_validation_validation_diagnostics_summary.csv') if not os.path.isfile(summary_path): summary_path = os.path.join(model_dir, f'{model_name}_validation_diagnostics_summary.csv') if os.path.isfile(summary_path): try: sdf = pd.read_csv(summary_path) if 'optimal_threshold_f1' in sdf.columns and len(sdf) > 0: threshold = float(sdf['optimal_threshold_f1'].iloc[0]) if 'temperature' in sdf.columns and len(sdf) > 0 and not os.path.isfile(temp_path): temperature = float(sdf['temperature'].iloc[0]) except Exception as e: print(f'Warning: diagnostics summary: {e}') return temperature, threshold def load_composition_model(model_dir: str): model_path = os.path.join(model_dir, 'composition_lr_model.joblib') if not os.path.isfile(model_path): raise FileNotFoundError( f'{model_path} not found. Re-run lr_sequence_composition_baseline.py to export ' 'composition_lr_model.joblib, or pass --model-dir to a directory that contains it.' ) bundle = joblib.load(model_path) clf = bundle['clf'] scaler = bundle['scaler'] feature_names = list(bundle.get('feature_names', FEATURE_NAMES)) if feature_names != FEATURE_NAMES: print( f'Warning: checkpoint feature_names differ from current script; ' f'using checkpoint order ({len(feature_names)} features).' ) return clf, scaler, feature_names def parse_arguments(): p = argparse.ArgumentParser( description='Evaluate FASTA with composition LR (gravy, pI, charge, complexity).', formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) p.add_argument('--fasta', required=True, help='Input FASTA file') p.add_argument( '--model-dir', default=None, help='Overrides MODEL_CHECKPOINT_DIR defined at the top of this script', ) p.add_argument('--output', required=True, help='Output directory') p.add_argument( '--csv', default=None, help='Optional CSV with id column (variant/sequence_id/id) and label (highFRET/lowFRET)', ) p.add_argument( '--no-detailed-metrics', action='store_true', help='With --csv: one-line summary only; skip diagnostic CSV/report output', ) p.add_argument( '--include-features', action='store_true', help='Add per-sequence composition feature columns to the predictions CSV', ) return p.parse_args() def read_fasta_robust(fasta_path): h = None buf = [] with open(fasta_path, 'r') as f: for line in f: line = line.strip() if not line: continue if line.startswith('>'): if h is not None: s = ''.join(buf) if s: yield (h, s) h = line[1:].strip() buf = [] elif h is not None: buf.append(line) if h is not None: s = ''.join(buf) if s: yield (h, s) def main(): args = parse_arguments() model_dir = resolve_model_dir(args.model_dir) detailed = args.csv and not args.no_detailed_metrics os.makedirs(args.output, exist_ok=True) print(f'Checkpoint directory: {model_dir}') clf, scaler, feature_names = load_composition_model(model_dir) temperature, threshold = load_inference_operating_point(model_dir) print(f'Operating point: temperature={temperature:.4f}, threshold={threshold:.4f}') sequence_ids: list[str] = [] sequences: list[str] = [] feature_rows: list[dict[str, float]] = [] skipped: list[tuple[str, int, str]] = [] for header, seq in read_fasta_robust(args.fasta): feats = compute_composition_features(seq) row = {name: feats[name] for name in feature_names} if any(math.isnan(v) for v in row.values()): skipped.append((header, len(standard_amino_acids(seq)), 'nan_composition_features')) continue sequence_ids.append(header) sequences.append(seq) feature_rows.append(row) for header, length, reason in skipped: print(f'Warning: skip {header} ({reason}, canonical_aa_len={length})') if not sequences: print('ERROR: No sequences to process.') return X = np.array([[row[n] for n in feature_names] for row in feature_rows], dtype=float) proba = predict_proba(clf, scaler, X) proba = em.apply_temperature_scaling(proba, temperature) results = pd.DataFrame( { 'sequence_id': sequence_ids, 'sequence_length': [len(standard_amino_acids(s)) for s in sequences], 'prediction_probability': proba, 'predicted_class': np.where(proba >= threshold, 'highFRET', 'lowFRET'), } ) if args.include_features: feat_df = pd.DataFrame(feature_rows) results = pd.concat([results, feat_df], axis=1) out_predictions = os.path.join(args.output, 'composition_lr_predictions.csv') results.to_csv(out_predictions, index=False) print(f'Predictions saved: {out_predictions}') if skipped: pd.DataFrame(skipped, columns=['sequence_id', 'length', 'reason']).to_csv( os.path.join(args.output, 'skipped_sequences.csv'), index=False ) if not args.csv or not os.path.isfile(args.csv): print('No labeled CSV; predictions only.') return try: _, _, label_map = em.build_label_map_int(pd.read_csv(args.csv)) except ValueError as e: print(f'Skipping metrics: {e}') return y_true_list: list[int] = [] y_prob_list: list[float] = [] for sid in sequence_ids: lab = label_map.get(str(sid).strip()) if lab is None: continue y_true_list.append(lab) y_prob_list.append( float(results.loc[results['sequence_id'] == sid, 'prediction_probability'].iloc[0]) ) if not y_true_list: print('No overlapping labels between FASTA IDs and CSV. Skipping metrics.') return y_true = np.asarray(y_true_list) y_prob = np.asarray(y_prob_list) y_pred = (y_prob >= threshold).astype(int) em.run_supervised_evaluation( y_true, y_prob, y_pred, 'composition_lr', args.output, threshold, detailed_metrics=detailed ) if detailed: print(f'Detailed diagnostics written under {args.output}') if __name__ == '__main__': main()