protloc-ai / scripts /test_variant_effect_v2.py
Tanoj22
Initial commit: ProtLoc-AI project setup and core app
cb6f1ba
"""
Attention-targeted smoke tests for residue-level variant effects.
Test A (Nucleus):
1) find a protein with moderate P(Nucleus) in [0.5, 0.8]
2) pick highest-attention K/R residue
3) mutate it to A
Test B (Membrane short proteins):
1) find a short protein (50-150 aa) with P(Membrane) > 0.5
2) pick highest-attention hydrophobic residue (L/V/I/F)
3) mutate it to D
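Also keeps the N-terminal wipeout test for Test A (mutate every K/R in positions 1-30 to A).
Test C (fixed protein ACC=Q6QNY1): mutate its top-3 attention hydrophobic residues (L/V/I/F/W/M) to D.
Test D (same protein): mutate its top-5 attention residues of any type (excluding A) to A.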
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple
import pandas as pd
import torch
from transformers import AutoModel, AutoTokenizer
# Directory two levels above this file. It serves as the project root for the
# `src.*` imports below and for the default CLI paths built in main().
ROOT = Path(__file__).resolve().parent.parent
# Put the root at the front of sys.path (only if it is not already present) so
# the `src.*` package imports that follow can be resolved from it.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from src.models.residue_classifier import FALLBACK_LABEL_NAMES, ResidueLocalizationClassifier # noqa: E402
from src.utils.device import resolve_torch_device # noqa: E402
# Hugging Face model id of the ESM-2 protein language model. _load_models loads
# both the tokenizer and the backbone from this id.
ESM_MODEL_NAME: str = "facebook/esm2_t33_650M_UR50D"
# Passed as the tokenizer's `max_length` (with truncation=True and special
# tokens enabled) in _embed_residue_sequence, so longer inputs are truncated.
MAX_LENGTH: int = 1024
def _load_sequences_from_csv(csv_path: Path) -> pd.DataFrame:
df = pd.read_csv(csv_path)
required = {"ACC", "Sequence"}
missing = required.difference(df.columns)
if missing:
raise ValueError(f"CSV missing required columns: {sorted(missing)}")
if len(df) == 0:
raise ValueError("CSV is empty.")
df = df.copy()
df["ACC"] = df["ACC"].astype(str)
df["Sequence"] = df["Sequence"].astype(str).str.upper().str.strip()
df = df[df["Sequence"].str.len() > 0].reset_index(drop=True)
if len(df) == 0:
raise ValueError("No non-empty sequences in CSV.")
return df
def _load_models(
    classifier_path: Path,
    device_req: str | None,
) -> Tuple[torch.device, Any, Any, ResidueLocalizationClassifier, List[str]]:
    """Load the ESM-2 backbone, its tokenizer, and the residue-level classifier head.

    Args:
        classifier_path: Checkpoint file read with ``torch.load``. It must
            deserialize to a dict: either a wrapper holding ``state_dict`` or
            ``model_state_dict`` plus optional hyper-parameters, or a bare
            state dict.
        device_req: Device request forwarded to ``resolve_torch_device``.

    Returns:
        ``(device, tokenizer, esm, classifier, label_names)``. Both models are
        switched to eval mode and moved to ``device``.

    Raises:
        ValueError: If the checkpoint is not a dict, or if its ``label_names``
            length disagrees with ``num_labels``.

    NOTE(review): ``torch.load`` unpickles the file, and on torch versions where
    ``weights_only`` defaults to False this can execute arbitrary code. Only
    point this at trusted checkpoints.
    """
    device = resolve_torch_device(device_req)

    tokenizer = AutoTokenizer.from_pretrained(ESM_MODEL_NAME)
    backbone = AutoModel.from_pretrained(
        ESM_MODEL_NAME,
        attn_implementation="eager",
        ignore_mismatched_sizes=True,
    )
    backbone.eval().to(device)

    checkpoint = torch.load(classifier_path, map_location="cpu")
    if not isinstance(checkpoint, dict):
        raise ValueError("Unsupported residue checkpoint format.")

    # Prefer a wrapped state dict; fall back to treating the whole checkpoint as one.
    if "state_dict" in checkpoint:
        weights = checkpoint["state_dict"]
    else:
        weights = checkpoint.get("model_state_dict", checkpoint)

    # Hyper-parameters default to these values when absent from the checkpoint.
    embedding_dim = int(checkpoint.get("embedding_dim", 1280))
    num_labels = int(checkpoint.get("num_labels", 11))
    stored_names = checkpoint.get("label_names") or FALLBACK_LABEL_NAMES[:num_labels]
    label_names = list(stored_names)
    if len(label_names) != num_labels:
        raise ValueError("Checkpoint label_names length does not match num_labels.")

    head = ResidueLocalizationClassifier(
        embedding_dim=embedding_dim,
        num_labels=num_labels,
        label_names=label_names,
        dropout=float(checkpoint.get("dropout", 0.3)),
        num_heads=int(checkpoint.get("num_heads", 4)),
    )
    head.load_state_dict(weights, strict=True)
    head.eval().to(device)
    return device, tokenizer, backbone, head, label_names
@torch.inference_mode()
def _embed_residue_sequence(
    sequence: str,
    tokenizer: Any,
    esm: Any,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Embed one protein with the ESM backbone and return per-residue features.

    The sequence is upper-cased and stripped, then tokenized with special
    tokens. Input longer than ``MAX_LENGTH`` tokens is truncated.

    Returns:
        ``(embeddings, mask)``:
        - ``embeddings`` is the backbone's last hidden state with the first
          and last token dropped, cast to float32 and given a leading batch
          dimension of 1.
        - ``mask`` is an all-True bool tensor of shape ``(1, L)`` on ``device``.

    Raises:
        ValueError: If fewer than three tokens survive tokenization, which
            leaves no residue positions between the boundary tokens.
    """
    encoded = tokenizer(
        [sequence.upper().strip()],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
        add_special_tokens=True,
    )
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    outputs = esm(**inputs, return_dict=True)

    token_count = int(inputs["attention_mask"][0].sum().item())
    if token_count < 3:
        raise ValueError("Tokenized sequence too short for residue embedding extraction.")

    # Drop the leading and trailing special tokens, keeping one row per residue.
    residues = outputs.last_hidden_state[0, 1 : token_count - 1, :].float()
    residue_mask = torch.ones(residues.size(0), dtype=torch.bool, device=device)
    return residues[None], residue_mask[None]
@torch.inference_mode()
def _predict_with_attention(
    sequence: str,
    tokenizer: Any,
    esm: Any,
    classifier: ResidueLocalizationClassifier,
    label_names: Sequence[str],
    device: torch.device,
) -> Tuple[Dict[str, float], List[float]]:
    """Score one sequence and return per-label probabilities plus classifier attention.

    Returns:
        ``(pred, attention)``:
        - ``pred`` maps ``str(label)`` for each entry of ``label_names`` to the
          sigmoid of the matching logit in the first batch row.
        - ``attention`` is the first batch row of the classifier's attention
          output, as Python floats. Presumably there is one weight per
          residue; confirm against ``ResidueLocalizationClassifier``.
    """
    embeddings, residue_mask = _embed_residue_sequence(sequence, tokenizer, esm, device)
    logits, attention = classifier.get_attention_weights(embeddings, mask=residue_mask)

    probabilities = torch.sigmoid(logits)[0].detach().cpu().numpy()
    pred = {str(name): float(probabilities[i]) for i, name in enumerate(label_names)}

    weights = [float(w) for w in attention[0].detach().cpu().numpy().tolist()]
    return pred, weights
def _apply_mutations(sequence: str, mutations: Sequence[Tuple[int, str, str]]) -> str:
seq = list(sequence.upper().strip())
n = len(seq)
for pos, orig, mut in mutations:
p = int(pos)
if p < 1 or p > n:
raise ValueError(f"Mutation position {p} out of range for length {n}")
o = str(orig).upper()
m = str(mut).upper()
if seq[p - 1] != o:
raise ValueError(f"Original AA mismatch at {p}: sequence has {seq[p-1]!r}, mutation expects {o!r}")
seq[p - 1] = m
return "".join(seq)
def _risk_from_delta(abs_delta: float) -> str:
if abs_delta > 0.3:
return "high"
if abs_delta >= 0.15:
return "medium"
if abs_delta >= 0.05:
return "low"
return "none"
def _best_attention_index(sequence: str, attention: Sequence[float], allowed: set[str]) -> Optional[int]:
best_idx: Optional[int] = None
best_score = -1.0
n = min(len(sequence), len(attention))
for i in range(n):
aa = sequence[i]
if aa not in allowed:
continue
score = float(attention[i])
if score > best_score:
best_score = score
best_idx = i
return best_idx
def _top_attention_indices(
sequence: str,
attention: Sequence[float],
*,
k: int,
allowed: Optional[set[str]] = None,
exclude_target_aa: Optional[str] = None,
) -> List[int]:
ranked: List[Tuple[float, int]] = []
n = min(len(sequence), len(attention))
for i in range(n):
aa = sequence[i]
if allowed is not None and aa not in allowed:
continue
if exclude_target_aa is not None and aa == exclude_target_aa:
continue
ranked.append((float(attention[i]), i))
ranked.sort(key=lambda t: t[0], reverse=True)
return [idx for _score, idx in ranked[: max(1, int(k))]]
def _find_case(
df: pd.DataFrame,
tokenizer: Any,
esm: Any,
classifier: ResidueLocalizationClassifier,
label_names: Sequence[str],
device: torch.device,
target_label: str,
prob_min: float,
prob_max: float,
allowed_residues: set[str],
length_min: Optional[int] = None,
length_max: Optional[int] = None,
max_scan: int = 500,
) -> Tuple[str, str, Dict[str, float], List[float], int, str]:
scanned = 0
for _, row in df.iterrows():
if scanned >= max_scan:
break
acc = str(row["ACC"])
seq = str(row["Sequence"]).upper().strip()
if length_min is not None and len(seq) < int(length_min):
continue
if length_max is not None and len(seq) > int(length_max):
continue
scanned += 1
pred, attn = _predict_with_attention(seq, tokenizer, esm, classifier, label_names, device)
p = float(pred.get(target_label, 0.0))
if p < prob_min or p > prob_max:
continue
idx = _best_attention_index(seq, attn, allowed_residues)
if idx is None:
continue
return acc, seq, pred, attn, idx, seq[idx]
raise RuntimeError(
f"Could not find case for label={target_label!r}, prob in [{prob_min}, {prob_max}], "
f"allowed residues={sorted(allowed_residues)} after scanning {scanned} candidate proteins."
)
def _report_prediction_deltas(
    original: Dict[str, float],
    mutant: Dict[str, float],
    tag: str,
) -> Tuple[Dict[str, float], str, float]:
    """Print original->mutant probabilities per label; return (deltas, most-affected label, its |delta|)."""
    print(f"\nOriginal vs mutant predictions (all locations) [{tag}]:")
    deltas: Dict[str, float] = {}
    for label in sorted(original.keys()):
        p0 = float(original[label])
        p1 = float(mutant[label])
        d = p1 - p0
        deltas[label] = d
        print(f" {label:24s} {p0:.4f} -> {p1:.4f} (delta {d:+.4f})")
    most_affected = max(deltas, key=lambda k: abs(deltas[k]))
    return deltas, most_affected, abs(float(deltas[most_affected]))


def main() -> None:
    """Run the attention-targeted variant-effect smoke tests and print the results.

    - Test A: find a protein with 0.5 <= P(Nucleus) <= 0.8 and mutate its
      top-attention K/R to A. Then, on the same protein, mutate every K/R in
      positions 1-30 to A. That N-terminal sub-test is skipped, without
      aborting the run, when there are no such residues.
    - Test B: find a 50-150 aa protein with P(Membrane) >= 0.5 and mutate its
      top-attention L/V/I/F to D.
    - Tests C and D: use the fixed accession Q6QNY1, if it is present in the
      CSV, and mutate its top-3 attention hydrophobics to D (C) and its top-5
      attention residues of any type to A (D).

    Raises:
        FileNotFoundError: If the classifier checkpoint or the dataset CSV is missing.
    """
    parser = argparse.ArgumentParser(description="Attention-targeted residue variant smoke tests.")
    parser.add_argument("--classifier-path", type=Path, default=ROOT / "models" / "best_residue_model.pt")
    parser.add_argument("--csv-path", type=Path, default=ROOT / "data" / "processed" / "deeploc_multilabel.csv")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--max-scan", type=int, default=500, help="Max proteins to scan while searching test cases.")
    args = parser.parse_args()

    # Relative paths are interpreted against the project root, not the CWD.
    classifier_path = args.classifier_path if args.classifier_path.is_absolute() else (ROOT / args.classifier_path).resolve()
    csv_path = args.csv_path if args.csv_path.is_absolute() else (ROOT / args.csv_path).resolve()
    if not classifier_path.is_file():
        raise FileNotFoundError(f"Missing classifier: {classifier_path}")
    if not csv_path.is_file():
        raise FileNotFoundError(f"Missing dataset CSV: {csv_path}")

    device, tokenizer, esm, classifier, label_names = _load_models(
        classifier_path=classifier_path,
        device_req=args.device,
    )
    df = _load_sequences_from_csv(csv_path)
    max_scan = max(1, int(args.max_scan))

    # ---------------- Test A: moderate-confidence Nucleus ----------------
    print("\n=== TEST A: Moderate-confidence Nucleus (0.5-0.8), mutate top-attention K/R -> A ===")
    acc_n, seq_n, pred_n, attn_n, idx_n, aa_n = _find_case(
        df,
        tokenizer,
        esm,
        classifier,
        label_names,
        device,
        target_label="Nucleus",
        prob_min=0.5,
        prob_max=0.8,
        allowed_residues={"K", "R"},
        max_scan=max_scan,
    )
    print(f"Selected ACC={acc_n} | length={len(seq_n)} | P(Nucleus)={pred_n.get('Nucleus', 0.0):.4f}")
    print(f"Top-attention basic residue: index {idx_n} (position {idx_n + 1}) = {aa_n}, attn={attn_n[idx_n]:.6f}")
    mut_n = [(idx_n + 1, aa_n, "A")]  # 1-based position for _apply_mutations
    pred_n_mut, _ = _predict_with_attention(
        _apply_mutations(seq_n, mut_n), tokenizer, esm, classifier, label_names, device
    )
    deltas_nuc, most_aff_nuc, max_abs_delta_nuc = _report_prediction_deltas(pred_n, pred_n_mut, "Nucleus targeted")
    print("\nClinical summary [Nucleus targeted]:")
    print(
        f"Single mutation {idx_n + 1}{aa_n}>A: most affected location={most_aff_nuc}, "
        f"delta={deltas_nuc[most_aff_nuc]:+.4f}"
    )
    print(f"Mislocalization risk: {_risk_from_delta(max_abs_delta_nuc)}")
    print(
        f"Key question (single R/K->A): max |delta| = {max_abs_delta_nuc:.4f} "
        f"=> {'YES' if max_abs_delta_nuc > 0.05 else 'NO'} (threshold > 0.05)"
    )

    # Additional stress test: mutate ALL K/R in N-terminal positions 1-30 to A.
    n_term_end = min(30, len(seq_n))
    n_term_mutations: List[Tuple[int, str, str]] = [
        (i + 1, aa, "A")  # 1-based positions for _apply_mutations
        for i, aa in enumerate(seq_n[:n_term_end])
        if aa in {"K", "R"}
    ]
    print("\n--- N-terminal basic-signal wipeout test (positions 1-30) ---")
    if not n_term_mutations:
        # Skip only this sub-test; Tests B-D below do not depend on it.
        print("No K/R residues found in positions 1-30; skipping combined N-terminal mutation test.")
    else:
        print(f"Found {len(n_term_mutations)} K/R residues in positions 1-30; mutating all to A.")
        print("Mutations:")
        print(" " + ", ".join(f"{p}{o}>A" for p, o, _ in n_term_mutations))
        predm_nterm, _ = _predict_with_attention(
            _apply_mutations(seq_n, n_term_mutations), tokenizer, esm, classifier, label_names, device
        )
        deltas_nterm, most_aff_n, max_abs_nterm = _report_prediction_deltas(
            pred_n, predm_nterm, "N-term combined mutation"
        )
        print("\nClinical summary [N-term combined mutation]:")
        print(
            f"N-term wipeout: most affected location={most_aff_n}, "
            f"delta={deltas_nterm[most_aff_n]:+.4f}"
        )
        print(
            f"\nMislocalization risk [N-term combined mutation]: "
            f"{_risk_from_delta(max_abs_nterm)}"
        )

    # ---------------- Test B: short membrane proteins ----------------
    print("\n=== TEST B: Short Membrane protein (50-150 aa), mutate top-attention L/V/I/F -> D ===")
    acc_m, seq_m, pred_m, attn_m, idx_m, aa_m = _find_case(
        df,
        tokenizer,
        esm,
        classifier,
        label_names,
        device,
        target_label="Membrane",
        prob_min=0.5,
        prob_max=1.0,
        allowed_residues={"L", "V", "I", "F"},
        length_min=50,
        length_max=150,
        max_scan=max_scan,
    )
    print(f"Selected ACC={acc_m} | length={len(seq_m)} | P(Membrane)={pred_m.get('Membrane', 0.0):.4f}")
    print(
        f"Top-attention hydrophobic residue: index {idx_m} (position {idx_m + 1}) = {aa_m}, "
        f"attn={attn_m[idx_m]:.6f}"
    )
    mut_m = [(idx_m + 1, aa_m, "D")]
    pred_m_mut, _ = _predict_with_attention(
        _apply_mutations(seq_m, mut_m), tokenizer, esm, classifier, label_names, device
    )
    deltas_mem, most_aff_mem, max_abs_mem = _report_prediction_deltas(pred_m, pred_m_mut, "Membrane targeted")
    print("\nClinical summary [Membrane targeted]:")
    print(
        f"Single mutation {idx_m + 1}{aa_m}>D: most affected location={most_aff_mem}, "
        f"delta={deltas_mem[most_aff_mem]:+.4f}"
    )
    print(f"Mislocalization risk: {_risk_from_delta(max_abs_mem)}")

    # ---------------- Test C: fixed protein ACC=Q6QNY1, top-3 hydrophobic -> D ----------------
    # Q6QNY1 is hard-coded; it is not necessarily the protein Test B selected.
    print("\n=== TEST C: ACC=Q6QNY1, top-3 attention hydrophobic (L/V/I/F/W/M) -> D ===")
    row_q = df[df["ACC"] == "Q6QNY1"]
    if len(row_q) == 0:
        print("ACC Q6QNY1 not found in CSV; skipping Test C and Test D.")
        return
    seq_q = str(row_q.iloc[0]["Sequence"]).upper().strip()
    print(f"Selected ACC=Q6QNY1 | length={len(seq_q)}")
    if len(seq_q) != 142:
        print("Warning: expected length 142 for Q6QNY1, got " + str(len(seq_q)))
    pred_q, attn_q = _predict_with_attention(seq_q, tokenizer, esm, classifier, label_names, device)

    top3_h = _top_attention_indices(
        seq_q,
        attn_q,
        k=3,
        allowed={"L", "V", "I", "F", "W", "M"},
        exclude_target_aa="D",
    )
    if len(top3_h) < 3:
        print("Could not find 3 hydrophobic residues for Test C; skipping.")
    else:
        mut_c = [(i + 1, seq_q[i], "D") for i in top3_h]
        print("Mutations:")
        print(" " + ", ".join(f"{p}{o}>D (attn={attn_q[p-1]:.6f})" for p, o, _ in mut_c))
        pred_q_c, _ = _predict_with_attention(
            _apply_mutations(seq_q, mut_c), tokenizer, esm, classifier, label_names, device
        )
        deltas_c, most_aff_c, max_abs_c = _report_prediction_deltas(pred_q, pred_q_c, "Test C")
        print(
            f"\nClinical summary [Test C]: most affected location={most_aff_c}, "
            f"delta={deltas_c[most_aff_c]:+.4f}"
        )
        print(f"Mislocalization risk [Test C]: {_risk_from_delta(max_abs_c)}")
        print(f"Key check [Test C]: max |delta| = {max_abs_c:.4f} => {'YES' if max_abs_c > 0.05 else 'NO'}")

    # ---------------- Test D: same protein, top-5 attention residues (any AA) -> A ----------------
    print("\n=== TEST D: ACC=Q6QNY1, top-5 attention residues (any AA) -> A ===")
    top5_any = _top_attention_indices(
        seq_q,
        attn_q,
        k=5,
        allowed=None,
        exclude_target_aa="A",
    )
    if len(top5_any) < 5:
        print("Could not find 5 mutable residues for Test D; skipping.")
        return
    mut_d = [(i + 1, seq_q[i], "A") for i in top5_any]
    print("Mutations:")
    print(" " + ", ".join(f"{p}{o}>A (attn={attn_q[p-1]:.6f})" for p, o, _ in mut_d))
    pred_q_d, _ = _predict_with_attention(
        _apply_mutations(seq_q, mut_d), tokenizer, esm, classifier, label_names, device
    )
    deltas_d, most_aff_d, max_abs_d = _report_prediction_deltas(pred_q, pred_q_d, "Test D")
    print(
        f"\nClinical summary [Test D]: most affected location={most_aff_d}, "
        f"delta={deltas_d[most_aff_d]:+.4f}"
    )
    print(f"Mislocalization risk [Test D]: {_risk_from_delta(max_abs_d)}")
    print(f"Key check [Test D]: max |delta| = {max_abs_d:.4f} => {'YES' if max_abs_d > 0.05 else 'NO'}")


if __name__ == "__main__":
    main()