VeloBind / src /applicability_domain.py
ym59's picture
Upload src/applicability_domain.py with huggingface_hub
d0d761f verified
# src/applicability_domain.py
#
# Three-layer applicability domain check.
# Flags garbage inputs before prediction is shown to the user.
#
# Layer 1 β€” Sequence sanity (catches poly-A, random chars, empty)
# Layer 2 β€” Ligand sanity (catches invalid SMILES, non-drug-like)
# Layer 3 β€” Embedding AD (catches proteins far from training dist)
import numpy as np
from collections import Counter
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
STANDARD_AA = set('ACDEFGHIKLMNPQRSTVWY')
# ── Layer 1: Sequence ─────────────────────────────────────────────────
def check_sequence(seq: str) -> tuple:
"""Returns (score 0-100, list of warning strings)."""
seq = seq.strip().upper()
warn = []
if not seq:
return 0.0, ["EMPTY_SEQUENCE"]
invalid_frac = sum(1 for c in seq if c not in STANDARD_AA) / len(seq)
if invalid_frac > 0.30:
return 0.0, [f"NOT_A_PROTEIN: {invalid_frac:.0%} non-standard characters"]
clean = ''.join(c for c in seq if c in STANDARD_AA)
if not clean:
return 0.0, ["NOT_A_PROTEIN: no standard amino acids found"]
score = 100.0
# Low complexity (poly-X)
counts = Counter(clean)
top_frac = counts.most_common(1)[0][1] / len(clean)
if top_frac > 0.40:
warn.append(f"LOW_COMPLEXITY: single AA = {top_frac:.0%} of sequence "
f"(likely poly-X repeat β€” prediction unreliable)")
score -= 50
# Shannon entropy
freqs = np.array(list(counts.values())) / len(clean)
entropy = -np.sum(freqs * np.log2(freqs + 1e-10))
if entropy < 2.5:
warn.append(f"LOW_ENTROPY: {entropy:.2f} bits β€” low complexity sequence")
score -= 25
# Unique AAs
if len(counts) < 5:
warn.append(f"LOW_DIVERSITY: only {len(counts)} unique amino acids")
score -= 25
if len(seq) < 50:
warn.append(f"SHORT: length {len(seq)} β€” may be a peptide, not a drug target")
score -= 15
if invalid_frac > 0:
warn.append(f"NON_STANDARD: {invalid_frac:.0%} non-standard characters present")
score -= 10
return max(0.0, score), warn
# ── Layer 2: Ligand ───────────────────────────────────────────────────
def check_ligand(smiles: str) -> tuple:
"""Returns (score 0-100, list of warning strings)."""
warn = []
if not smiles or not smiles.strip():
return 0.0, ["EMPTY_SMILES"]
mol = Chem.MolFromSmiles(smiles.strip())
if mol is None:
return 0.0, ["INVALID_SMILES: RDKit could not parse this string"]
score = 100.0
mw = Descriptors.MolWt(mol)
if mw < 100:
warn.append(f"LOW_MW: {mw:.1f} Da β€” likely a fragment or solvent")
score -= 40
elif mw > 1000:
warn.append(f"HIGH_MW: {mw:.1f} Da β€” may be outside training distribution")
score -= 20
n_heavy = mol.GetNumHeavyAtoms()
if n_heavy < 5:
warn.append(f"TOO_SMALL: only {n_heavy} heavy atoms")
score -= 40
allowed = {1, 5, 6, 7, 8, 9, 14, 15, 16, 17, 35, 53}
exotic = {a.GetSymbol() for a in mol.GetAtoms()
if a.GetAtomicNum() not in allowed and a.GetAtomicNum() != 0}
if exotic:
warn.append(f"EXOTIC_ATOMS: {', '.join(sorted(exotic))} β€” rare in training data")
score -= 20
return max(0.0, score), warn
# ── Layer 3: Embedding AD ─────────────────────────────────────────────
class EmbeddingAD:
"""
kNN applicability domain in ESM embedding space.
Flags proteins far from the training distribution.
"""
def __init__(self, k: int = 5, percentile: float = 95.0):
self.k = k
self.percentile = percentile
self.fitted = False
def fit(self, train_embeddings: np.ndarray):
from sklearn.neighbors import NearestNeighbors
print(f"Fitting Embedding AD on {len(train_embeddings)} proteins...")
self.nn = NearestNeighbors(n_neighbors=self.k + 1,
metric='cosine', n_jobs=-1)
self.nn.fit(train_embeddings.astype(np.float32))
dists, _ = self.nn.kneighbors(train_embeddings.astype(np.float32))
knn_dists = dists[:, 1:].mean(axis=1)
self.threshold = np.percentile(knn_dists, self.percentile)
self.fitted = True
print(f" AD threshold ({self.percentile}th pct): {self.threshold:.4f}")
return self
def score(self, embedding: np.ndarray) -> tuple:
"""Returns (distance, score 0-100, in_domain bool)."""
if not self.fitted:
return None, 100.0, True
dists, _ = self.nn.kneighbors(
embedding.reshape(1, -1).astype(np.float32), n_neighbors=self.k
)
dist = float(dists[0].mean())
in_domain = dist <= self.threshold
if dist <= self.threshold:
ad_score = 50 + 50 * (1 - dist / self.threshold)
else:
ad_score = max(0.0, 50 * (1 - (dist - self.threshold) / self.threshold))
return dist, ad_score, in_domain
# ── Combined report ───────────────────────────────────────────────────
def confidence_report(seq: str, smiles: str,
embedding: np.ndarray = None,
ad_model: EmbeddingAD = None) -> dict:
seq_score, seq_warn = check_sequence(seq)
lig_score, lig_warn = check_ligand(smiles)
ad_score, ad_dist, in_domain = 100.0, None, True
if embedding is not None and ad_model is not None:
ad_dist, ad_score, in_domain = ad_model.score(embedding)
if not in_domain:
seq_warn.append(
f"OUT_OF_DOMAIN: protein distance={ad_dist:.3f} "
f"(threshold={ad_model.threshold:.3f})"
)
all_warn = seq_warn + lig_warn
w_ad = 0.20 if ad_model else 0.0
w_seq = 0.55 if not ad_model else 0.45
w_lig = 1.0 - w_seq - w_ad
overall = w_seq * seq_score + w_lig * lig_score + w_ad * ad_score
if overall >= 70 and seq_score >= 60 and lig_score >= 60:
flag = 'RELIABLE'
elif overall >= 40:
flag = 'UNCERTAIN'
else:
flag = 'UNRELIABLE'
return {
'flag': flag,
'overall': round(overall, 1),
'seq_score': round(seq_score, 1),
'lig_score': round(lig_score, 1),
'ad_score': round(ad_score, 1),
'in_domain': in_domain,
'warnings': all_warn,
}