""" Smoke test for ProteinRelocalizer using multiple source->target pairs. The script tries candidate shifts: - Cytoplasm -> Extracellular - Cytoplasm -> Membrane - Cytoplasm -> Nucleus For each pair, it picks a protein in a configurable source-probability band and low target probability, then runs a short probe optimization to estimate movement. It chooses the pair with the clearest target-probability gain and runs the full demo optimization. """ from __future__ import annotations import argparse import sys from pathlib import Path from typing import Any, Dict, List, Tuple import pandas as pd ROOT = Path(__file__).resolve().parent.parent if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) class _QuietTqdm: """Disable progress bar noise for smoke tests.""" def __init__(self, iterable, *args, **kwargs): self._it = iterable def __iter__(self): return iter(self._it) def set_postfix(self, *args, **kwargs) -> None: pass def close(self) -> None: pass import tqdm as _tqdm_mod # noqa: E402 _tqdm_mod.tqdm = _QuietTqdm # noqa: E402 from src.design.relocalizer import ProteinRelocalizer # noqa: E402 def _pick_test_protein_for_pair( df: pd.DataFrame, relocalizer: ProteinRelocalizer, *, source: str, target: str, min_seq_len: int, max_seq_len: int, source_prob_min: float, source_prob_max: float, target_prob_max: float, max_scan: int, ) -> Tuple[str, str, Dict[str, Any]]: """ From CSV rows labeled by `source`, pick a sequence in [min_seq_len, max_seq_len] whose classifier probabilities satisfy: source_prob_min <= P(source) <= source_prob_max P(target) < target_prob_max Rows are ordered **longest first** (within the length band). """ required = {"ACC", "Sequence", source, target} missing = required.difference(df.columns) if missing: raise ValueError(f"Dataset missing columns: {sorted(missing)}") pool = df[df[source] == 1].copy() if pool.empty: raise ValueError(f"No rows with {source} label == 1 in CSV.") pool["seq_len"] = pool["Sequence"].astype(str).str.len() pool = pool[(pool["seq_len"] >= min_seq_len) & (pool["seq_len"] <= max_seq_len)] pool = pool.sort_values("seq_len", ascending=False) if pool.empty: raise ValueError( f"No Cytoplasm-positive rows with {min_seq_len} <= length <= {max_seq_len}." ) tried = 0 for _, row in pool.iterrows(): if tried >= max_scan: break tried += 1 seq = str(row["Sequence"]).upper().strip() if len(seq) < min_seq_len or len(seq) > max_seq_len: continue scored = relocalizer.score_variant(seq) probs = scored["localization_probs"] p_src = float(probs.get(source, 0.0)) p_tgt = float(probs.get(target, 0.0)) if source_prob_min <= p_src <= source_prob_max and p_tgt < target_prob_max: return str(row["ACC"]), seq, scored raise RuntimeError( f"No sequence found after scoring {tried} {source}-labeled proteins " f"(len in [{min_seq_len}, {max_seq_len}], need {source_prob_min}<=P({source})<={source_prob_max} " f"and P({target})<{target_prob_max}). " "Try increasing --max-scan or relaxing probability/length bounds." ) def _print_trajectory(traj: list, source: str, target: str) -> None: print("\n--- Optimization trajectory ---") if not traj: print("(empty)") return for step in traj: it = step.get("iteration", "?") p_tgt = step.get("best_target_prob") p_src = step.get("best_source_prob") nmut = step.get("num_mutations") neval = step.get("num_candidates_evaluated") err = step.get("error") parts = [f"iter={it}"] if p_tgt is not None: parts.append(f"P({target})={p_tgt:.4f}") if p_src is not None: parts.append(f"P({source})={p_src:.4f}") if nmut is not None: parts.append(f"muts={nmut}") if neval is not None: parts.append(f"evaluated={neval}") if err: parts.append(f"error={err}") print(" " + " | ".join(parts)) def _print_top_candidate( results: Dict[str, Any], source: str, target: str, ) -> None: tops = results.get("top_candidates") or [] print("\n--- Top candidate ---") if not tops: print("(none)") return best = tops[0] orig = results["original_scores"]["localization_probs"] newp = best["localization_probs"] print(f"Composite score: {best['composite_score']:.4f}") print(f"Plausibility: {best['plausibility_score']:.4f}") print("") print("Probability changes (original -> candidate):") for loc in (source, target): o = float(orig.get(loc, 0.0)) n = float(newp.get(loc, 0.0)) print(f" P({loc}): {o:.4f} -> {n:.4f} (delta {n - o:+.4f})") muts = best.get("mutations") or [] print("") print(f"Mutations vs original ({len(muts)}):") if muts: shown = muts[:40] line = ", ".join(f"{p}{a}>{b}" for p, a, b in shown) if len(muts) > 40: line += ", ..." print(f" {line}") else: print(" (none; same as original)") def _print_all_location_changes(results: Dict[str, Any]) -> None: print("\n--- Full localization profile shift (all labels) ---") tops = results.get("top_candidates") or [] if not tops: print("(no top candidate)") return orig = results.get("original_scores", {}).get("localization_probs", {}) best = tops[0].get("localization_probs", {}) labels = sorted(set(orig.keys()) | set(best.keys())) for label in labels: o = float(orig.get(label, 0.0)) n = float(best.get(label, 0.0)) print(f" {label:24s} {o:.4f} -> {n:.4f} (delta {n - o:+.4f})") def _movement_score(results: Dict[str, Any], source: str, target: str) -> float: """Positive score means shift toward target and away from source.""" orig = results.get("original_scores", {}).get("localization_probs", {}) tops = results.get("top_candidates") or [] if not tops: return float("-inf") best = tops[0].get("localization_probs", {}) delta_t = float(best.get(target, 0.0)) - float(orig.get(target, 0.0)) delta_s = float(orig.get(source, 0.0)) - float(best.get(source, 0.0)) return delta_t + 0.5 * delta_s def main() -> None: try: import captum # noqa: F401 except ImportError: print( "ERROR: captum is required for relocalize() (integrated gradients). " "Install with: pip install captum", file=sys.stderr, ) sys.exit(2) p = argparse.ArgumentParser(description="Smoke test ProteinRelocalizer (lightweight, fast).") p.add_argument( "--classifier-path", type=Path, default=ROOT / "models" / "best_model.pt", help="Trained classifier checkpoint.", ) p.add_argument( "--csv-path", type=Path, default=ROOT / "data" / "processed" / "deeploc_multilabel.csv", help="Multilabel CSV with ACC, Sequence, Cytoplasm, ...", ) p.add_argument("--device", default=None, help="cuda | cpu (default: auto)") p.add_argument( "--min-seq-len", type=int, default=150, help="Minimum sequence length for the test protein (default: 150).", ) p.add_argument( "--max-seq-len", type=int, default=350, help="Maximum sequence length for the test protein (default: 350).", ) p.add_argument( "--source-prob-min", type=float, default=0.5, help="Minimum P(source) for selected test protein (default: 0.5).", ) p.add_argument( "--source-prob-max", type=float, default=0.9, help="Maximum P(source) for selected test protein (default: 0.9).", ) p.add_argument( "--target-prob-max", type=float, default=0.25, help="Require P(target) below this for selected protein (default: 0.25).", ) p.add_argument( "--max-scan", type=int, default=500, help="Max CSV rows (after length filter) to score when searching for a match.", ) p.add_argument( "--probe-iterations", type=int, default=3, help="Short probe iterations for pair selection (default: 3).", ) p.add_argument( "--probe-candidates", type=int, default=10, help="Short probe candidates/iter for pair selection (default: 10).", ) args = p.parse_args() classifier_path = args.classifier_path if not classifier_path.is_absolute(): classifier_path = (ROOT / classifier_path).resolve() csv_path = args.csv_path if not csv_path.is_absolute(): csv_path = (ROOT / csv_path).resolve() if not classifier_path.is_file(): print(f"ERROR: missing classifier: {classifier_path}", file=sys.stderr) sys.exit(1) if not csv_path.is_file(): print(f"ERROR: missing dataset CSV: {csv_path}", file=sys.stderr) sys.exit(1) print("Loading ProteinRelocalizer.from_lightweight() (t33 encoder + t12 MLM)...", flush=True) device = args.device relocalizer = ProteinRelocalizer.from_lightweight( classifier_path=classifier_path, device=device, ) if args.min_seq_len > args.max_seq_len: print("ERROR: --min-seq-len must be <= --max-seq-len", file=sys.stderr) sys.exit(1) if args.source_prob_min > args.source_prob_max: print("ERROR: --source-prob-min must be <= --source-prob-max", file=sys.stderr) sys.exit(1) df = pd.read_csv(csv_path) pair_candidates: List[Tuple[str, str]] = [ ("Cytoplasm", "Extracellular"), ("Cytoplasm", "Membrane"), ("Cytoplasm", "Nucleus"), ] pair_trials: List[Dict[str, Any]] = [] print("\nTrying source->target pairs to find a clear demo shift...", flush=True) for source, target in pair_candidates: print(f"\n[Pair probe] {source} -> {target}", flush=True) try: acc, sequence, initial = _pick_test_protein_for_pair( df, relocalizer, source=source, target=target, min_seq_len=args.min_seq_len, max_seq_len=args.max_seq_len, source_prob_min=args.source_prob_min, source_prob_max=args.source_prob_max, target_prob_max=args.target_prob_max, max_scan=args.max_scan, ) probe = relocalizer.relocalize( sequence, source_location=source, target_location=target, n_iterations=args.probe_iterations, candidates_per_iteration=args.probe_candidates, ) move = _movement_score(probe, source, target) print( f" ACC={acc} len={len(sequence)} " f"P({source})={initial['localization_probs'].get(source, 0):.4f} " f"P({target})={initial['localization_probs'].get(target, 0):.4f} " f"probe_movement={move:+.4f}", flush=True, ) pair_trials.append( { "source": source, "target": target, "acc": acc, "sequence": sequence, "initial": initial, "probe_results": probe, "movement": move, } ) except Exception as ex: print(f" skipped: {ex}", flush=True) if not pair_trials: raise RuntimeError("No valid pair/test-protein found for the configured constraints.") pair_trials.sort(key=lambda x: float(x["movement"]), reverse=True) chosen = pair_trials[0] source = str(chosen["source"]) target = str(chosen["target"]) acc = str(chosen["acc"]) sequence = str(chosen["sequence"]) initial = chosen["initial"] print( f"\nSelected demo pair: {source} -> {target} " f"(probe movement {float(chosen['movement']):+.4f})", flush=True, ) print( f"Selected ACC={acc}, length={len(sequence)} " f"P({source})={initial['localization_probs'].get(source, 0):.4f} " f"P({target})={initial['localization_probs'].get(target, 0):.4f}", flush=True, ) print("\nRunning full relocalize (10 iterations, 20 candidates/iter)...", flush=True) results = relocalizer.relocalize( sequence, source_location=source, target_location=target, n_iterations=10, candidates_per_iteration=20, ) _print_trajectory(results.get("optimization_trajectory") or [], source, target) _print_top_candidate(results, source, target) _print_all_location_changes(results) print("\n--- Summary ---") print(relocalizer.get_summary(results)) print(f"\nTotal time: {results.get('total_time_seconds', 0):.2f}s " f"variants scored: {results.get('total_variants_evaluated', 0)}") if __name__ == "__main__": main()