FRET-FACS / evaluation /evaluate_lr_sequence_composition.py
neuwirtt
Initial release: FRET-FACS pipeline, weights, and datasets
6e4d123
Raw
History Blame Contribute Delete
8.5 kB
"""Evaluate FASTA with a trained composition-features logistic regression (optional labeled CSV)."""
from __future__ import annotations
import argparse
import math
import os
import sys
from pathlib import Path
import joblib
import numpy as np
import pandas as pd
#
# =====================================================================
# Default checkpoint location: a training output directory or the
# repository's `weights/composition_lr/` folder (see evaluation/README.md).
# The directory must contain composition_lr_model.joblib (plus the
# calibration CSVs, when available). Override on the command line via
# --model-dir.
# =====================================================================
# Repository root = the parent of the directory that holds this script.
_REPO_ROOT = Path(__file__).resolve().parents[1]
MODEL_CHECKPOINT_DIR = _REPO_ROOT / 'weights' / 'composition_lr'

# The baseline training module lives under <repo>/models, which is not a
# package on the default import path. Prepend it to sys.path *before* the
# import below; the statement order here is significant.
_MODELS_DIR = _REPO_ROOT / 'models'
if str(_MODELS_DIR) not in sys.path:
    sys.path.insert(0, str(_MODELS_DIR))
from lr_sequence_composition_baseline import (  # noqa: E402
    FEATURE_NAMES,
    compute_composition_features,
    predict_proba,
    standard_amino_acids,
)

# Make this script's own directory importable as well, so the sibling
# `eval_metrics` module resolves regardless of the current working directory.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
if _SCRIPT_DIR not in sys.path:
    sys.path.insert(0, _SCRIPT_DIR)
import eval_metrics as em  # noqa: E402
def resolve_model_dir(cli_value: str | None) -> str:
    """Return the checkpoint directory to load the model from.

    Args:
        cli_value: Value of ``--model-dir``; ``None`` or an empty string
            selects the module-level ``MODEL_CHECKPOINT_DIR`` default.

    Returns:
        An absolute path string: the absolutised CLI value when one was
        given, otherwise the resolved default checkpoint directory.
    """
    if not cli_value:
        # Nothing supplied on the command line: fall back to the default.
        return str(MODEL_CHECKPOINT_DIR.resolve())
    return os.path.abspath(cli_value)
def load_inference_operating_point(model_dir: str, model_name: str = 'composition_lr') -> tuple[float, float]:
    """Load the calibration temperature and decision threshold stored with a checkpoint.

    Two optional CSV files inside ``model_dir`` are consulted:

    * ``{model_name}_validation_temperature.csv`` -- column ``temperature``.
    * ``{model_name}_validation_validation_diagnostics_summary.csv`` or, if
      that is absent, ``{model_name}_validation_diagnostics_summary.csv`` --
      column ``optimal_threshold_f1`` and, used only when the temperature
      file does not exist, column ``temperature``.

    Only the first row of each file is read. Loading is best-effort: a
    missing, unreadable, or malformed file never raises; a warning is
    printed and the corresponding default is kept. Values that are NaN or
    infinite (e.g. a blank CSV cell), or a temperature that is not strictly
    positive, are rejected the same way instead of silently corrupting every
    downstream prediction.

    Args:
        model_dir: Directory holding the checkpoint and calibration CSVs.
        model_name: Filename prefix of the calibration CSVs.

    Returns:
        ``(temperature, threshold)``; defaults are ``(1.0, 0.5)``.
    """
    temperature = 1.0
    threshold = 0.5

    def _first_value(frame, column):
        """Return the first-row value of ``column`` as a float, or None if unavailable."""
        if column not in frame.columns or len(frame) == 0:
            return None
        return float(frame[column].iloc[0])

    def _checked_temperature(value, source, current):
        """Return ``value`` if it is a usable temperature, else warn and keep ``current``."""
        if value is None:
            return current
        # A temperature must be finite and > 0 to be meaningful as a scaling
        # factor (the scaling helper presumably divides by it -- confirm).
        if math.isfinite(value) and value > 0:
            return value
        print(f'Warning: {source}: ignoring invalid temperature {value!r}')
        return current

    temp_path = os.path.join(model_dir, f'{model_name}_validation_temperature.csv')
    temp_file_present = os.path.isfile(temp_path)
    if temp_file_present:
        try:
            loaded = _first_value(pd.read_csv(temp_path), 'temperature')
            temperature = _checked_temperature(loaded, 'temperature file', temperature)
        except Exception as e:  # deliberate best-effort: never abort inference over calibration
            print(f'Warning: temperature file: {e}')

    # Prefer the doubled "validation_validation" filename, then the plain one.
    summary_path = os.path.join(model_dir, f'{model_name}_validation_validation_diagnostics_summary.csv')
    if not os.path.isfile(summary_path):
        summary_path = os.path.join(model_dir, f'{model_name}_validation_diagnostics_summary.csv')
    if os.path.isfile(summary_path):
        try:
            sdf = pd.read_csv(summary_path)
            loaded = _first_value(sdf, 'optimal_threshold_f1')
            if loaded is not None:
                if math.isfinite(loaded):
                    threshold = loaded
                else:
                    # A NaN threshold would make every ``proba >= threshold`` False.
                    print(f'Warning: diagnostics summary: ignoring invalid threshold {loaded!r}')
            # The summary's temperature is only a fallback for a missing
            # dedicated temperature file.
            if not temp_file_present:
                loaded = _first_value(sdf, 'temperature')
                temperature = _checked_temperature(loaded, 'diagnostics summary', temperature)
        except Exception as e:  # deliberate best-effort, see above
            print(f'Warning: diagnostics summary: {e}')
    return temperature, threshold
def load_composition_model(model_dir: str):
    """Load the exported classifier bundle from ``model_dir``.

    The bundle file ``composition_lr_model.joblib`` is expected to be a
    mapping with keys ``'clf'`` and ``'scaler'`` and, optionally,
    ``'feature_names'`` (the imported ``FEATURE_NAMES`` is used when absent).

    Args:
        model_dir: Directory that contains ``composition_lr_model.joblib``.

    Returns:
        ``(clf, scaler, feature_names)`` where ``feature_names`` is a list in
        the order stored in the checkpoint.

    Raises:
        FileNotFoundError: If the bundle file does not exist.
    """
    checkpoint_file = os.path.join(model_dir, 'composition_lr_model.joblib')
    if os.path.isfile(checkpoint_file):
        # NOTE(review): joblib.load unpickles the file, which can execute
        # arbitrary code -- only point --model-dir at trusted checkpoints.
        payload = joblib.load(checkpoint_file)
        classifier, feature_scaler = payload['clf'], payload['scaler']
        checkpoint_features = list(payload.get('feature_names', FEATURE_NAMES))
        if checkpoint_features != FEATURE_NAMES:
            # The checkpoint's own ordering wins so columns match training.
            print(
                f'Warning: checkpoint feature_names differ from current script; '
                f'using checkpoint order ({len(checkpoint_features)} features).'
            )
        return classifier, feature_scaler, checkpoint_features
    raise FileNotFoundError(
        f'{checkpoint_file} not found. Re-run lr_sequence_composition_baseline.py to export '
        'composition_lr_model.joblib, or pass --model-dir to a directory that contains it.'
    )
def parse_arguments():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with attributes ``fasta``, ``model_dir``,
        ``output``, ``csv``, ``no_detailed_metrics`` and ``include_features``.
    """
    parser = argparse.ArgumentParser(
        description='Evaluate FASTA with composition LR (gravy, pI, charge, complexity).',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Required input.
    parser.add_argument('--fasta', required=True, help='Input FASTA file')

    # Optional checkpoint override.
    parser.add_argument(
        '--model-dir',
        default=None,
        help='Overrides MODEL_CHECKPOINT_DIR defined at the top of this script',
    )

    # Required output location.
    parser.add_argument('--output', required=True, help='Output directory')

    # Optional ground-truth labels for supervised metrics.
    parser.add_argument(
        '--csv',
        default=None,
        help='Optional CSV with id column (variant/sequence_id/id) and label (highFRET/lowFRET)',
    )

    # Boolean switches.
    parser.add_argument(
        '--no-detailed-metrics',
        action='store_true',
        help='With --csv: one-line summary only; skip diagnostic CSV/report output',
    )
    parser.add_argument(
        '--include-features',
        action='store_true',
        help='Add per-sequence composition feature columns to the predictions CSV',
    )

    namespace = parser.parse_args()
    return namespace
def read_fasta_robust(fasta_path):
    """Lazily yield ``(header, sequence)`` pairs from a FASTA file.

    Tolerant parsing rules:
      * blank lines are ignored;
      * text appearing before the first ``>`` header is discarded;
      * the header is the text after ``>`` with surrounding whitespace removed;
      * multi-line sequences are concatenated (each line stripped);
      * records whose sequence is empty are dropped.

    Args:
        fasta_path: Path of the FASTA file to read.

    Yields:
        Tuples ``(header, sequence)`` in file order.
    """
    current_header = None
    chunks = []
    with open(fasta_path, 'r') as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text == '':
                continue
            if not text.startswith('>'):
                # Sequence data only counts once a header has been seen.
                if current_header is not None:
                    chunks.append(text)
                continue
            # New header: emit the finished record. ``chunks`` only ever
            # holds non-empty strings, so it is empty exactly when the
            # joined sequence would be empty.
            if current_header is not None and chunks:
                yield (current_header, ''.join(chunks))
            current_header = text[1:].strip()
            chunks = []
    # Emit the trailing record, if it has any sequence.
    if current_header is not None and chunks:
        yield (current_header, ''.join(chunks))
def main():
    """Run composition-LR inference on a FASTA file and optionally score it.

    Workflow:
      1. Parse CLI arguments and resolve the checkpoint directory.
      2. Load the classifier bundle and the (temperature, threshold)
         operating point.
      3. Compute composition features per FASTA record; records with any
         NaN feature are skipped and reported.
      4. Predict probabilities, apply temperature scaling, and write
         ``composition_lr_predictions.csv`` (plus ``skipped_sequences.csv``
         when records were skipped) into the output directory.
      5. If ``--csv`` names an existing label file, align labels to FASTA
         IDs and hand them to ``em.run_supervised_evaluation``.

    The function returns ``None`` on every path; problems are reported via
    printed messages rather than exceptions or exit codes.
    """
    args = parse_arguments()
    model_dir = resolve_model_dir(args.model_dir)
    # Force a real bool: ``args.csv and ...`` alone would yield None (or '')
    # when --csv is not given.
    detailed = bool(args.csv) and not args.no_detailed_metrics
    os.makedirs(args.output, exist_ok=True)

    print(f'Checkpoint directory: {model_dir}')
    clf, scaler, feature_names = load_composition_model(model_dir)
    temperature, threshold = load_inference_operating_point(model_dir)
    print(f'Operating point: temperature={temperature:.4f}, threshold={threshold:.4f}')

    sequence_ids: list[str] = []
    sequences: list[str] = []
    feature_rows: list[dict[str, float]] = []
    skipped: list[tuple[str, int, str]] = []
    for header, seq in read_fasta_robust(args.fasta):
        feats = compute_composition_features(seq)
        # Keep only the checkpoint's features, in the checkpoint's order.
        row = {name: feats[name] for name in feature_names}
        if any(math.isnan(v) for v in row.values()):
            skipped.append((header, len(standard_amino_acids(seq)), 'nan_composition_features'))
            continue
        sequence_ids.append(header)
        sequences.append(seq)
        feature_rows.append(row)

    for header, length, reason in skipped:
        print(f'Warning: skip {header} ({reason}, canonical_aa_len={length})')
    if not sequences:
        print('ERROR: No sequences to process.')
        return

    X = np.array([[row[n] for n in feature_names] for row in feature_rows], dtype=float)
    proba = predict_proba(clf, scaler, X)
    proba = em.apply_temperature_scaling(proba, temperature)

    results = pd.DataFrame(
        {
            'sequence_id': sequence_ids,
            'sequence_length': [len(standard_amino_acids(s)) for s in sequences],
            'prediction_probability': proba,
            'predicted_class': np.where(proba >= threshold, 'highFRET', 'lowFRET'),
        }
    )
    if args.include_features:
        feat_df = pd.DataFrame(feature_rows)
        results = pd.concat([results, feat_df], axis=1)

    out_predictions = os.path.join(args.output, 'composition_lr_predictions.csv')
    results.to_csv(out_predictions, index=False)
    print(f'Predictions saved: {out_predictions}')
    if skipped:
        pd.DataFrame(skipped, columns=['sequence_id', 'length', 'reason']).to_csv(
            os.path.join(args.output, 'skipped_sequences.csv'), index=False
        )

    if not args.csv:
        print('No labeled CSV; predictions only.')
        return
    if not os.path.isfile(args.csv):
        # A path was supplied but does not exist: say so explicitly instead
        # of claiming that no CSV was given.
        print(f'Warning: labeled CSV not found: {args.csv}; predictions only.')
        return

    try:
        _, _, label_map = em.build_label_map_int(pd.read_csv(args.csv))
    except ValueError as e:
        print(f'Skipping metrics: {e}')
        return

    # Pair each FASTA record with its own probability in one linear pass.
    # (Re-filtering ``results`` by ID per record is quadratic and, with
    # duplicate IDs, would reuse the first record's probability for all.)
    y_true_list: list[int] = []
    y_prob_list: list[float] = []
    for sid, prob in zip(sequence_ids, results['prediction_probability'].tolist()):
        lab = label_map.get(str(sid).strip())
        if lab is None:
            continue
        y_true_list.append(lab)
        y_prob_list.append(float(prob))

    if not y_true_list:
        print('No overlapping labels between FASTA IDs and CSV. Skipping metrics.')
        return

    y_true = np.asarray(y_true_list)
    y_prob = np.asarray(y_prob_list)
    y_pred = (y_prob >= threshold).astype(int)
    em.run_supervised_evaluation(
        y_true, y_prob, y_pred, 'composition_lr', args.output, threshold, detailed_metrics=detailed
    )
    if detailed:
        print(f'Detailed diagnostics written under {args.output}')
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()