Spaces:
Running
Running
File size: 9,040 Bytes
"""
Controlled side-by-side variant-effect comparison:
mean-pooled classifier vs residue-attention classifier.
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
from typing import Any, Dict, List, Mapping, Sequence, Tuple
import pandas as pd
import torch
from transformers import AutoModel, AutoTokenizer
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from src.analysis.variant_effect import VariantEffectPredictor # noqa: E402
from src.models.residue_classifier import FALLBACK_LABEL_NAMES, ResidueLocalizationClassifier # noqa: E402
from src.utils.device import resolve_torch_device # noqa: E402
# Hugging Face model id used for both the tokenizer and the ESM backbone.
ESM_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
# Passed to the tokenizer as ``max_length`` (with truncation enabled).
MAX_LENGTH = 1024

# Mutation panels: test name -> list of (1-based position, original AA, mutant AA).
# Mutations within a panel are applied together to build one mutant sequence.
TESTS: Dict[str, List[Tuple[int, str, str]]] = {
    "Test B": [
        (45, "F", "D"),
    ],
    "Test C": [
        (45, "F", "D"),
        (9, "L", "D"),
        (8, "V", "D"),
    ],
    "Test D": [
        (142, "R", "A"),
        (141, "K", "A"),
        (45, "F", "A"),
        (41, "C", "A"),
        (9, "L", "A"),
    ],
}
class ResidueVariantEffectPredictor:
"""Simple residue-level analog of VariantEffectPredictor.predict_variant_effect."""
def __init__(
self,
classifier_path: str | Path,
device: str | torch.device | None = None,
) -> None:
self.device = resolve_torch_device(device)
self.classifier_path = Path(classifier_path).expanduser().resolve()
if not self.classifier_path.is_file():
raise FileNotFoundError(f"Missing residue classifier checkpoint: {self.classifier_path}")
self.tokenizer = AutoTokenizer.from_pretrained(ESM_MODEL_NAME)
self.esm_model = AutoModel.from_pretrained(
ESM_MODEL_NAME,
attn_implementation="eager",
ignore_mismatched_sizes=True,
)
self.esm_model.eval().to(self.device)
ckpt = torch.load(self.classifier_path, map_location="cpu")
if not isinstance(ckpt, Mapping):
raise ValueError("Unsupported residue checkpoint format.")
state = ckpt.get("state_dict", ckpt.get("model_state_dict", ckpt))
embedding_dim = int(ckpt.get("embedding_dim", 1280))
num_labels = int(ckpt.get("num_labels", 11))
label_names = list(ckpt.get("label_names") or FALLBACK_LABEL_NAMES[:num_labels])
if len(label_names) != num_labels:
raise ValueError("Residue checkpoint label_names length mismatch.")
self.label_names = label_names
self.classifier = ResidueLocalizationClassifier(
embedding_dim=embedding_dim,
num_labels=num_labels,
label_names=label_names,
dropout=float(ckpt.get("dropout", 0.3)),
num_heads=int(ckpt.get("num_heads", 4)),
)
self.classifier.load_state_dict(state, strict=True)
self.classifier.eval().to(self.device)
@torch.inference_mode()
def _predict_proba(self, sequence: str) -> Dict[str, float]:
seq = sequence.upper().strip()
toks = self.tokenizer(
[seq],
return_tensors="pt",
padding=True,
truncation=True,
max_length=MAX_LENGTH,
add_special_tokens=True,
)
toks = {k: v.to(self.device) for k, v in toks.items()}
out = self.esm_model(**toks, return_dict=True)
hidden = out.last_hidden_state
attn_mask = toks["attention_mask"]
valid_len = int(attn_mask[0].sum().item())
if valid_len < 3:
raise ValueError("Tokenized sequence too short.")
core = hidden[0, 1 : valid_len - 1, :].float()
x = core.unsqueeze(0)
mask = torch.ones((1, core.shape[0]), dtype=torch.bool, device=self.device)
logits = self.classifier(x, mask=mask)
probs = torch.sigmoid(logits)[0].detach().cpu().tolist()
return {self.label_names[i]: float(probs[i]) for i in range(len(self.label_names))}
@staticmethod
def _apply_mutations(sequence: str, mutations: Sequence[Tuple[int, str, str]]) -> str:
seq = list(sequence.upper().strip())
n = len(seq)
for pos, orig, mut in mutations:
if pos < 1 or pos > n:
raise ValueError(f"Mutation position {pos} out of range for length {n}")
if seq[pos - 1] != orig:
raise ValueError(
f"Original AA mismatch at {pos}: expected {seq[pos - 1]!r}, got mutation original {orig!r}"
)
seq[pos - 1] = mut
return "".join(seq)
def predict_variant_effect(
self,
original_sequence: str,
mutations: Sequence[Tuple[int, str, str]],
) -> Dict[str, Any]:
p0 = self._predict_proba(original_sequence)
seqm = self._apply_mutations(original_sequence, mutations)
pm = self._predict_proba(seqm)
deltas = {k: float(pm[k] - p0[k]) for k in self.label_names}
return {
"original_predictions": p0,
"mutant_predictions": pm,
"deltas": deltas,
"mutant_sequence": seqm,
}
def _load_q6qny1_sequence(csv_path: Path) -> str:
    """Return the upper-cased, stripped sequence for accession Q6QNY1.

    The CSV must contain ``ACC`` and ``Sequence`` columns; the first row whose
    (whitespace-stripped) ``ACC`` equals ``Q6QNY1`` is used. A length other than
    142 only prints a warning.

    Raises:
        ValueError: If required columns are missing, Q6QNY1 is absent, or its
            ``Sequence`` cell is missing/empty.
    """
    df = pd.read_csv(csv_path)
    required = {"ACC", "Sequence"}
    if not required.issubset(df.columns):
        raise ValueError(f"CSV must contain {sorted(required)} columns.")
    # Strip accession cells so stray whitespace in the CSV cannot hide the row.
    rows = df[df["ACC"].astype(str).str.strip() == "Q6QNY1"]
    if rows.empty:
        raise ValueError("ACC Q6QNY1 not found in CSV.")
    raw = rows.iloc[0]["Sequence"]
    # An empty cell is read as NaN; str() would turn it into the bogus sequence "NAN".
    if pd.isna(raw) or not str(raw).strip():
        raise ValueError("ACC Q6QNY1 has an empty Sequence in CSV.")
    seq = str(raw).upper().strip()
    if len(seq) != 142:
        print(f"Warning: expected length 142 for Q6QNY1, got {len(seq)}")
    return seq
def _print_test_table(
    test_name: str,
    mean_deltas: Dict[str, float],
    residue_deltas: Dict[str, float],
) -> None:
    """Print a per-location table comparing the two models' probability deltas.

    Rows cover the union of locations from both dicts, sorted by name. A
    location missing from one model is shown as ``n/a`` rather than a made-up
    0.0 delta. The improvement column is ``|residue| / |mean|``: ``inf`` when
    only the mean-pooled delta is ~0, and ``1.0x`` when both are ~0.
    """
    eps = 1e-9
    loc_w, num_w, imp_w = 26, 17, 11
    print(f"\n=== {test_name} ===")
    # Header, separator, and rows share the same column widths so the table lines up.
    print(
        f"{'Location':<{loc_w}} | {'Mean-pooled delta':>{num_w}} | "
        f"{'Residue delta':>{num_w}} | Improvement"
    )
    print(f"{'-' * loc_w}-+-{'-' * num_w}-+-{'-' * num_w}-+-{'-' * imp_w}")
    for loc in sorted(set(mean_deltas) | set(residue_deltas)):
        m = mean_deltas.get(loc)
        r = residue_deltas.get(loc)
        m_txt = "n/a" if m is None else f"{float(m):+.4f}"
        r_txt = "n/a" if r is None else f"{float(r):+.4f}"
        if m is None or r is None:
            imp = "n/a"
        elif abs(float(m)) < eps:
            # Use the same tolerance for the numerator so float noise isn't reported as "inf".
            imp = "inf" if abs(float(r)) >= eps else "1.0x"
        else:
            imp = f"{abs(float(r)) / abs(float(m)):.1f}x"
        print(f"{str(loc):<{loc_w}} | {m_txt:>{num_w}} | {r_txt:>{num_w}} | {imp}")
def _resolve_cli_path(path: Path, root: Path) -> Path:
    """Return *path* with ``~`` expanded; relative paths are anchored at *root*.

    Relative CLI paths are resolved against the project root rather than the
    current working directory, matching how the defaults in ``main`` are built.
    """
    expanded = Path(path).expanduser()
    if expanded.is_absolute():
        return expanded
    return (root / expanded).resolve()


def main() -> None:
    """Compare mean-pooled vs residue-level variant-effect sensitivity on Q6QNY1.

    Loads both classifiers and applies every mutation panel in ``TESTS`` to the
    Q6QNY1 sequence read from the CSV. Prints a per-location delta table for
    each panel, then the average absolute delta of each model and their ratio.

    Raises:
        FileNotFoundError: If the CSV or either checkpoint is missing.
    """
    parser = argparse.ArgumentParser(description="Controlled comparison: mean-pooled vs residue variant-effect sensitivity.")
    parser.add_argument("--csv-path", type=Path, default=ROOT / "data" / "processed" / "deeploc_multilabel.csv")
    parser.add_argument("--mean-model", type=Path, default=ROOT / "models" / "best_model.pt")
    parser.add_argument("--residue-model", type=Path, default=ROOT / "models" / "best_residue_model.pt")
    parser.add_argument("--device", default="cuda")
    args = parser.parse_args()

    csv_path = _resolve_cli_path(args.csv_path, ROOT)
    mean_model = _resolve_cli_path(args.mean_model, ROOT)
    residue_model = _resolve_cli_path(args.residue_model, ROOT)
    for path, what in (
        (csv_path, "CSV"),
        (mean_model, "mean-pooled checkpoint"),
        (residue_model, "residue checkpoint"),
    ):
        if not path.is_file():
            raise FileNotFoundError(f"Missing {what}: {path}")

    sequence = _load_q6qny1_sequence(csv_path)
    print(f"Protein: ACC=Q6QNY1 | length={len(sequence)}")
    mean_predictor = VariantEffectPredictor(classifier_path=mean_model, device=args.device)
    residue_predictor = ResidueVariantEffectPredictor(classifier_path=residue_model, device=args.device)

    all_abs_mean: List[float] = []
    all_abs_res: List[float] = []
    for test_name, muts in TESTS.items():
        mut_txt = ", ".join(f"{p}{o}>{m}" for p, o, m in muts)
        # Prefix with the test name so the line cannot be mistaken for part of
        # the preceding test's table.
        print(f"\n{test_name} mutations: {mut_txt}")
        out_mean = mean_predictor.predict_variant_effect(sequence, muts)
        out_res = residue_predictor.predict_variant_effect(sequence, muts)
        d_mean = {k: float(v) for k, v in (out_mean.get("deltas") or {}).items()}
        d_res = {k: float(v) for k, v in (out_res.get("deltas") or {}).items()}
        _print_test_table(test_name, d_mean, d_res)
        all_abs_mean.extend(abs(x) for x in d_mean.values())
        all_abs_res.extend(abs(x) for x in d_res.values())

    avg_abs_mean = sum(all_abs_mean) / max(1, len(all_abs_mean))
    avg_abs_res = sum(all_abs_res) / max(1, len(all_abs_res))
    if avg_abs_mean >= 1e-12:
        overall_ratio_txt = f"{(avg_abs_res / avg_abs_mean):.2f}x"
    elif avg_abs_res >= 1e-12:
        overall_ratio_txt = "inf"
    else:
        # Neither model reacts to the mutations: the ratio is undefined, not infinite.
        overall_ratio_txt = "n/a"
    print("\n=== Summary ===")
    print(f"Average |delta| (mean-pooled): {avg_abs_mean:.6f}")
    print(f"Average |delta| (residue): {avg_abs_res:.6f}")
    print(f"Average sensitivity improvement across all locations/tests: {overall_ratio_txt}")


if __name__ == "__main__":
    main()
|