# protloc-ai/scripts/test_relocalizer.py
# Tanoj22
# Initial commit: ProtLoc-AI project setup and core app
# cb6f1ba
"""
Smoke test for ProteinRelocalizer using multiple source->target pairs.
The script tries candidate shifts:
- Cytoplasm -> Extracellular
- Cytoplasm -> Membrane
- Cytoplasm -> Nucleus
For each pair, it picks a protein in a configurable source-probability band and
low target probability, then runs a short probe optimization to estimate movement.
It chooses the pair with the highest probe movement score (target-probability
gain plus half of the source-probability drop) and runs the full demo
optimization on that pair.
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
from typing import Any, Dict, List, Tuple
import pandas as pd
# Directory one level above this script's folder. It is prepended to sys.path
# (once) so that the `src.` import further down can be resolved when this file
# is executed directly as a script.
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
class _QuietTqdm:
    """Silent stand-in for ``tqdm.tqdm`` that keeps smoke-test output clean.

    The stub wraps an iterable and yields its items unchanged while turning
    the progress-bar methods implemented here into no-ops. Manual-mode
    construction without an iterable (``tqdm(total=n)``) and use as a context
    manager are accepted so that code written against those tqdm usage styles
    does not fail when the stub is installed in place of the real class.

    Args:
        iterable: Object to iterate over; ``None`` (manual mode) iterates as
            an empty sequence.
        *args, **kwargs: Accepted and ignored (``desc``, ``total``, ...).
    """

    def __init__(self, iterable=None, *args, **kwargs):
        # Manual-mode bars have no iterable; iterate as empty instead of
        # failing later with "NoneType is not iterable".
        self._it = iterable if iterable is not None else ()

    def __iter__(self):
        return iter(self._it)

    def __enter__(self) -> "_QuietTqdm":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Returning None (falsy) lets exceptions from the with-body propagate.
        self.close()

    def update(self, n=1) -> None:
        """Ignore manual progress increments."""

    def set_description(self, *args, **kwargs) -> None:
        """Ignore description updates."""

    def set_postfix(self, *args, **kwargs) -> None:
        """Ignore postfix updates."""

    def refresh(self, *args, **kwargs) -> None:
        """Ignore redraw requests."""

    def close(self) -> None:
        """Nothing to release; present for API compatibility."""
# Replace tqdm's progress-bar class with the silent stub *before* the
# relocalizer module is imported, so that a `from tqdm import tqdm` executed
# during that import binds the stub.
# NOTE(review): code that obtains tqdm through another path (e.g. `tqdm.auto`)
# or that was imported earlier is presumably unaffected — confirm against
# src.design.relocalizer.
import tqdm as _tqdm_mod # noqa: E402
_tqdm_mod.tqdm = _QuietTqdm # noqa: E402
from src.design.relocalizer import ProteinRelocalizer # noqa: E402
def _pick_test_protein_for_pair(
    df: pd.DataFrame,
    relocalizer: ProteinRelocalizer,
    *,
    source: str,
    target: str,
    min_seq_len: int,
    max_seq_len: int,
    source_prob_min: float,
    source_prob_max: float,
    target_prob_max: float,
    max_scan: int,
) -> Tuple[str, str, Dict[str, Any]]:
    """Pick a `source`-labeled protein suitable for a source->target probe.

    Rows whose `source` column equals 1 are normalised (stripped, upper-cased),
    restricted to ``min_seq_len <= len(sequence) <= max_seq_len`` and visited
    **longest first** (ties keep their CSV order). At most `max_scan` of them
    are scored with ``relocalizer.score_variant``; the first one satisfying

        source_prob_min <= P(source) <= source_prob_max
        P(target) < target_prob_max

    is returned.

    Args:
        df: Table with at least the columns ``ACC``, ``Sequence``, `source`
            and `target`.
        relocalizer: Object exposing ``score_variant(seq)`` whose result holds
            a ``"localization_probs"`` mapping.
        source: Label column / probability key of the current location.
        target: Label column / probability key of the desired location.
        min_seq_len: Inclusive lower bound on sequence length.
        max_seq_len: Inclusive upper bound on sequence length.
        source_prob_min: Inclusive lower bound on P(source).
        source_prob_max: Inclusive upper bound on P(source).
        target_prob_max: Exclusive upper bound on P(target).
        max_scan: Maximum number of sequences to score.

    Returns:
        ``(accession, sequence, score_variant_result)`` for the chosen row.

    Raises:
        ValueError: A required column is missing, no row is labeled `source`,
            or no labeled row falls inside the length band.
        RuntimeError: None of the scored sequences met the probability bounds.
    """
    required = {"ACC", "Sequence", source, target}
    missing = required.difference(df.columns)
    if missing:
        raise ValueError(f"Dataset missing columns: {sorted(missing)}")

    pool = df[df[source] == 1].copy()
    if pool.empty:
        raise ValueError(f"No rows with {source} label == 1 in CSV.")

    # Drop missing sequences (they would otherwise turn into the string "nan")
    # and normalise once, so the length filter sees exactly what gets scored.
    pool = pool.dropna(subset=["Sequence"])
    pool["Sequence"] = pool["Sequence"].astype(str).str.strip().str.upper()
    pool["seq_len"] = pool["Sequence"].str.len()
    pool = pool[pool["seq_len"].between(min_seq_len, max_seq_len)]
    if pool.empty:
        raise ValueError(
            f"No {source}-positive rows with {min_seq_len} <= length <= {max_seq_len}."
        )

    # A stable sort keeps equal-length rows in CSV order, so the pick is
    # reproducible from run to run.
    pool = pool.sort_values("seq_len", ascending=False, kind="mergesort")

    # Clamp so that a non-positive max_scan scores nothing (head(-n) would
    # instead mean "all but the last n rows").
    candidates = pool.head(max(max_scan, 0))

    tried = 0
    for acc, seq in zip(candidates["ACC"], candidates["Sequence"]):
        tried += 1
        scored = relocalizer.score_variant(seq)
        probs = scored["localization_probs"]
        p_src = float(probs.get(source, 0.0))
        p_tgt = float(probs.get(target, 0.0))
        if source_prob_min <= p_src <= source_prob_max and p_tgt < target_prob_max:
            return str(acc), seq, scored

    raise RuntimeError(
        f"No sequence found after scoring {tried} {source}-labeled proteins "
        f"(len in [{min_seq_len}, {max_seq_len}], need {source_prob_min}<=P({source})<={source_prob_max} "
        f"and P({target})<{target_prob_max}). "
        "Try increasing --max-scan or relaxing probability/length bounds."
    )
def _print_trajectory(traj: list, source: str, target: str) -> None:
    """Print one line per optimization step found in `traj`.

    Each step is a mapping; only the fields that are present (not ``None``)
    are shown, in the order: iteration, P(target), P(source), mutation count,
    evaluated-candidate count, error. An empty trajectory prints ``(empty)``.
    """
    print("\n--- Optimization trajectory ---")
    if not traj:
        print("(empty)")
        return

    for record in traj:
        fields = [f"iter={record.get('iteration', '?')}"]

        # Probabilities are rendered with four decimals, target before source.
        probabilities = (
            (target, record.get("best_target_prob")),
            (source, record.get("best_source_prob")),
        )
        fields.extend(
            f"P({name})={value:.4f}" for name, value in probabilities if value is not None
        )

        # Plain counters are rendered as-is.
        counters = (
            ("muts", record.get("num_mutations")),
            ("evaluated", record.get("num_candidates_evaluated")),
        )
        fields.extend(f"{tag}={value}" for tag, value in counters if value is not None)

        failure = record.get("error")
        if failure:
            fields.append(f"error={failure}")

        print(" " + " | ".join(fields))
def _print_top_candidate(
    results: Dict[str, Any],
    source: str,
    target: str,
) -> None:
    """Print scores, source/target probability deltas and mutations of the best candidate.

    The best candidate is the first entry of ``results["top_candidates"]``;
    when that list is missing or empty only ``(none)`` is printed. At most 40
    mutations are listed, followed by ``...`` when there are more.
    """
    ranked = results.get("top_candidates") or []
    print("\n--- Top candidate ---")
    if not ranked:
        print("(none)")
        return

    winner = ranked[0]
    before = results["original_scores"]["localization_probs"]
    after = winner["localization_probs"]

    print(f"Composite score: {winner['composite_score']:.4f}")
    print(f"Plausibility: {winner['plausibility_score']:.4f}")
    print("")
    print("Probability changes (original -> candidate):")
    for place in (source, target):
        old_p = float(before.get(place, 0.0))
        new_p = float(after.get(place, 0.0))
        print(f" P({place}): {old_p:.4f} -> {new_p:.4f} (delta {new_p - old_p:+.4f})")

    mutation_list = winner.get("mutations") or []
    print("")
    print(f"Mutations vs original ({len(mutation_list)}):")
    if not mutation_list:
        print(" (none; same as original)")
        return

    # Each mutation is a (position, old residue, new residue) triple.
    rendered = [f"{pos}{old}>{new}" for pos, old, new in mutation_list[:40]]
    if len(mutation_list) > 40:
        rendered.append("...")
    print(" " + ", ".join(rendered))
def _print_all_location_changes(results: Dict[str, Any]) -> None:
    """Print original vs. best-candidate probability for every label seen in either profile.

    Labels are listed alphabetically; a label missing from one profile counts
    as 0.0 there. Prints ``(no top candidate)`` when no candidate exists.
    """
    print("\n--- Full localization profile shift (all labels) ---")
    ranked = results.get("top_candidates") or []
    if not ranked:
        print("(no top candidate)")
        return

    before = results.get("original_scores", {}).get("localization_probs", {})
    after = ranked[0].get("localization_probs", {})

    # Union of the labels present in either profile.
    for name in sorted({*before.keys(), *after.keys()}):
        old_p = float(before.get(name, 0.0))
        new_p = float(after.get(name, 0.0))
        print(f" {name:24s} {old_p:.4f} -> {new_p:.4f} (delta {new_p - old_p:+.4f})")
def _movement_score(results: Dict[str, Any], source: str, target: str) -> float:
    """Score how far the best candidate moved toward `target` and away from `source`.

    The score is the gain in P(target) plus half of the drop in P(source),
    both measured between the original profile and the first top candidate.
    Returns negative infinity when there is no top candidate.
    """
    baseline = results.get("original_scores", {}).get("localization_probs", {})
    ranked = results.get("top_candidates") or []
    if ranked:
        shifted = ranked[0].get("localization_probs", {})
        target_gain = float(shifted.get(target, 0.0)) - float(baseline.get(target, 0.0))
        source_drop = float(baseline.get(source, 0.0)) - float(shifted.get(source, 0.0))
        return target_gain + 0.5 * source_drop
    return float("-inf")
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the relocalizer smoke test."""
    p = argparse.ArgumentParser(description="Smoke test ProteinRelocalizer (lightweight, fast).")
    p.add_argument(
        "--classifier-path",
        type=Path,
        default=ROOT / "models" / "best_model.pt",
        help="Trained classifier checkpoint.",
    )
    p.add_argument(
        "--csv-path",
        type=Path,
        default=ROOT / "data" / "processed" / "deeploc_multilabel.csv",
        help="Multilabel CSV with ACC, Sequence, Cytoplasm, ...",
    )
    p.add_argument("--device", default=None, help="cuda | cpu (default: auto)")
    p.add_argument(
        "--min-seq-len",
        type=int,
        default=150,
        help="Minimum sequence length for the test protein (default: 150).",
    )
    p.add_argument(
        "--max-seq-len",
        type=int,
        default=350,
        help="Maximum sequence length for the test protein (default: 350).",
    )
    p.add_argument(
        "--source-prob-min",
        type=float,
        default=0.5,
        help="Minimum P(source) for selected test protein (default: 0.5).",
    )
    p.add_argument(
        "--source-prob-max",
        type=float,
        default=0.9,
        help="Maximum P(source) for selected test protein (default: 0.9).",
    )
    p.add_argument(
        "--target-prob-max",
        type=float,
        default=0.25,
        help="Require P(target) below this for selected protein (default: 0.25).",
    )
    p.add_argument(
        "--max-scan",
        type=int,
        default=500,
        help="Max CSV rows (after length filter) to score when searching for a match.",
    )
    p.add_argument(
        "--probe-iterations",
        type=int,
        default=3,
        help="Short probe iterations for pair selection (default: 3).",
    )
    p.add_argument(
        "--probe-candidates",
        type=int,
        default=10,
        help="Short probe candidates/iter for pair selection (default: 10).",
    )
    p.add_argument(
        "--iterations",
        type=int,
        default=10,
        help="Iterations for the full demo run (default: 10).",
    )
    p.add_argument(
        "--candidates",
        type=int,
        default=20,
        help="Candidates/iter for the full demo run (default: 20).",
    )
    return p


def main() -> None:
    """Run the smoke test: probe each source->target pair, then fully optimize the best one.

    Steps:
      1. Check that ``captum`` is importable (exit code 2 otherwise).
      2. Parse arguments and validate paths and bounds (exit code 1 on error).
         All of this happens before the model is loaded so that a bad
         invocation fails immediately.
      3. For every candidate pair, pick a test protein and run a short probe
         optimization. A pair that raises is reported and skipped.
      4. Run the full optimization on the pair with the highest movement
         score and print trajectory, top candidate, profile shift and summary.

    Raises:
        RuntimeError: No pair produced a usable test protein / probe result.
    """
    try:
        import captum  # noqa: F401
    except ImportError:
        print(
            "ERROR: captum is required for relocalize() (integrated gradients). "
            "Install with: pip install captum",
            file=sys.stderr,
        )
        sys.exit(2)

    args = _build_arg_parser().parse_args()

    # Relative paths are interpreted relative to ROOT, not the current directory.
    classifier_path = args.classifier_path
    if not classifier_path.is_absolute():
        classifier_path = (ROOT / classifier_path).resolve()
    csv_path = args.csv_path
    if not csv_path.is_absolute():
        csv_path = (ROOT / csv_path).resolve()

    if not classifier_path.is_file():
        print(f"ERROR: missing classifier: {classifier_path}", file=sys.stderr)
        sys.exit(1)
    if not csv_path.is_file():
        print(f"ERROR: missing dataset CSV: {csv_path}", file=sys.stderr)
        sys.exit(1)
    if args.min_seq_len > args.max_seq_len:
        print("ERROR: --min-seq-len must be <= --max-seq-len", file=sys.stderr)
        sys.exit(1)
    if args.source_prob_min > args.source_prob_max:
        print("ERROR: --source-prob-min must be <= --source-prob-max", file=sys.stderr)
        sys.exit(1)

    # Read the dataset before loading the model: an unreadable CSV should not
    # cost a full model load first.
    df = pd.read_csv(csv_path)

    print("Loading ProteinRelocalizer.from_lightweight() (t33 encoder + t12 MLM)...", flush=True)
    relocalizer = ProteinRelocalizer.from_lightweight(
        classifier_path=classifier_path,
        device=args.device,
    )

    pair_candidates: List[Tuple[str, str]] = [
        ("Cytoplasm", "Extracellular"),
        ("Cytoplasm", "Membrane"),
        ("Cytoplasm", "Nucleus"),
    ]
    pair_trials: List[Dict[str, Any]] = []
    print("\nTrying source->target pairs to find a clear demo shift...", flush=True)
    for source, target in pair_candidates:
        print(f"\n[Pair probe] {source} -> {target}", flush=True)
        try:
            acc, sequence, initial = _pick_test_protein_for_pair(
                df,
                relocalizer,
                source=source,
                target=target,
                min_seq_len=args.min_seq_len,
                max_seq_len=args.max_seq_len,
                source_prob_min=args.source_prob_min,
                source_prob_max=args.source_prob_max,
                target_prob_max=args.target_prob_max,
                max_scan=args.max_scan,
            )
            probe = relocalizer.relocalize(
                sequence,
                source_location=source,
                target_location=target,
                n_iterations=args.probe_iterations,
                candidates_per_iteration=args.probe_candidates,
            )
            move = _movement_score(probe, source, target)
            print(
                f" ACC={acc} len={len(sequence)} "
                f"P({source})={initial['localization_probs'].get(source, 0):.4f} "
                f"P({target})={initial['localization_probs'].get(target, 0):.4f} "
                f"probe_movement={move:+.4f}",
                flush=True,
            )
            pair_trials.append(
                {
                    "source": source,
                    "target": target,
                    "acc": acc,
                    "sequence": sequence,
                    "initial": initial,
                    "probe_results": probe,
                    "movement": move,
                }
            )
        except Exception as ex:
            # Deliberately broad: one failing pair must not abort the probe of
            # the remaining pairs; the reason is reported and the pair skipped.
            print(f" skipped: {ex}", flush=True)

    if not pair_trials:
        raise RuntimeError("No valid pair/test-protein found for the configured constraints.")

    # max() returns the first maximal entry, i.e. ties go to the earlier pair.
    chosen = max(pair_trials, key=lambda trial: float(trial["movement"]))
    source = chosen["source"]
    target = chosen["target"]
    acc = chosen["acc"]
    sequence = chosen["sequence"]
    initial = chosen["initial"]
    print(
        f"\nSelected demo pair: {source} -> {target} "
        f"(probe movement {float(chosen['movement']):+.4f})",
        flush=True,
    )
    print(
        f"Selected ACC={acc}, length={len(sequence)} "
        f"P({source})={initial['localization_probs'].get(source, 0):.4f} "
        f"P({target})={initial['localization_probs'].get(target, 0):.4f}",
        flush=True,
    )

    print(
        f"\nRunning full relocalize ({args.iterations} iterations, "
        f"{args.candidates} candidates/iter)...",
        flush=True,
    )
    results = relocalizer.relocalize(
        sequence,
        source_location=source,
        target_location=target,
        n_iterations=args.iterations,
        candidates_per_iteration=args.candidates,
    )

    _print_trajectory(results.get("optimization_trajectory") or [], source, target)
    _print_top_candidate(results, source, target)
    _print_all_location_changes(results)
    print("\n--- Summary ---")
    print(relocalizer.get_summary(results))
    print(
        f"\nTotal time: {results.get('total_time_seconds', 0):.2f}s "
        f"variants scored: {results.get('total_variants_evaluated', 0)}"
    )


if __name__ == "__main__":
    main()