| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import numpy as np |
| from collections import Counter |
| from rdkit import Chem |
| from rdkit.Chem import Descriptors |
| from rdkit import RDLogger |
| RDLogger.DisableLog('rdApp.*') |
|
|
| STANDARD_AA = set('ACDEFGHIKLMNPQRSTVWY') |
|
|
|
|
| |
|
|
| def check_sequence(seq: str) -> tuple: |
| """Returns (score 0-100, list of warning strings).""" |
| seq = seq.strip().upper() |
| warn = [] |
|
|
| if not seq: |
| return 0.0, ["EMPTY_SEQUENCE"] |
|
|
| invalid_frac = sum(1 for c in seq if c not in STANDARD_AA) / len(seq) |
| if invalid_frac > 0.30: |
| return 0.0, [f"NOT_A_PROTEIN: {invalid_frac:.0%} non-standard characters"] |
|
|
| clean = ''.join(c for c in seq if c in STANDARD_AA) |
| if not clean: |
| return 0.0, ["NOT_A_PROTEIN: no standard amino acids found"] |
|
|
| score = 100.0 |
|
|
| |
| counts = Counter(clean) |
| top_frac = counts.most_common(1)[0][1] / len(clean) |
| if top_frac > 0.40: |
| warn.append(f"LOW_COMPLEXITY: single AA = {top_frac:.0%} of sequence " |
| f"(likely poly-X repeat β prediction unreliable)") |
| score -= 50 |
|
|
| |
| freqs = np.array(list(counts.values())) / len(clean) |
| entropy = -np.sum(freqs * np.log2(freqs + 1e-10)) |
| if entropy < 2.5: |
| warn.append(f"LOW_ENTROPY: {entropy:.2f} bits β low complexity sequence") |
| score -= 25 |
|
|
| |
| if len(counts) < 5: |
| warn.append(f"LOW_DIVERSITY: only {len(counts)} unique amino acids") |
| score -= 25 |
|
|
| if len(seq) < 50: |
| warn.append(f"SHORT: length {len(seq)} β may be a peptide, not a drug target") |
| score -= 15 |
|
|
| if invalid_frac > 0: |
| warn.append(f"NON_STANDARD: {invalid_frac:.0%} non-standard characters present") |
| score -= 10 |
|
|
| return max(0.0, score), warn |
|
|
|
|
| |
|
|
| def check_ligand(smiles: str) -> tuple: |
| """Returns (score 0-100, list of warning strings).""" |
| warn = [] |
| if not smiles or not smiles.strip(): |
| return 0.0, ["EMPTY_SMILES"] |
|
|
| mol = Chem.MolFromSmiles(smiles.strip()) |
| if mol is None: |
| return 0.0, ["INVALID_SMILES: RDKit could not parse this string"] |
|
|
| score = 100.0 |
|
|
| mw = Descriptors.MolWt(mol) |
| if mw < 100: |
| warn.append(f"LOW_MW: {mw:.1f} Da β likely a fragment or solvent") |
| score -= 40 |
| elif mw > 1000: |
| warn.append(f"HIGH_MW: {mw:.1f} Da β may be outside training distribution") |
| score -= 20 |
|
|
| n_heavy = mol.GetNumHeavyAtoms() |
| if n_heavy < 5: |
| warn.append(f"TOO_SMALL: only {n_heavy} heavy atoms") |
| score -= 40 |
|
|
| allowed = {1, 5, 6, 7, 8, 9, 14, 15, 16, 17, 35, 53} |
| exotic = {a.GetSymbol() for a in mol.GetAtoms() |
| if a.GetAtomicNum() not in allowed and a.GetAtomicNum() != 0} |
| if exotic: |
| warn.append(f"EXOTIC_ATOMS: {', '.join(sorted(exotic))} β rare in training data") |
| score -= 20 |
|
|
| return max(0.0, score), warn |
|
|
|
|
| |
|
|
| class EmbeddingAD: |
| """ |
| kNN applicability domain in ESM embedding space. |
| Flags proteins far from the training distribution. |
| """ |
| def __init__(self, k: int = 5, percentile: float = 95.0): |
| self.k = k |
| self.percentile = percentile |
| self.fitted = False |
|
|
| def fit(self, train_embeddings: np.ndarray): |
| from sklearn.neighbors import NearestNeighbors |
| print(f"Fitting Embedding AD on {len(train_embeddings)} proteins...") |
| self.nn = NearestNeighbors(n_neighbors=self.k + 1, |
| metric='cosine', n_jobs=-1) |
| self.nn.fit(train_embeddings.astype(np.float32)) |
| dists, _ = self.nn.kneighbors(train_embeddings.astype(np.float32)) |
| knn_dists = dists[:, 1:].mean(axis=1) |
| self.threshold = np.percentile(knn_dists, self.percentile) |
| self.fitted = True |
| print(f" AD threshold ({self.percentile}th pct): {self.threshold:.4f}") |
| return self |
|
|
| def score(self, embedding: np.ndarray) -> tuple: |
| """Returns (distance, score 0-100, in_domain bool).""" |
| if not self.fitted: |
| return None, 100.0, True |
| dists, _ = self.nn.kneighbors( |
| embedding.reshape(1, -1).astype(np.float32), n_neighbors=self.k |
| ) |
| dist = float(dists[0].mean()) |
| in_domain = dist <= self.threshold |
| if dist <= self.threshold: |
| ad_score = 50 + 50 * (1 - dist / self.threshold) |
| else: |
| ad_score = max(0.0, 50 * (1 - (dist - self.threshold) / self.threshold)) |
| return dist, ad_score, in_domain |
|
|
|
|
| |
|
|
| def confidence_report(seq: str, smiles: str, |
| embedding: np.ndarray = None, |
| ad_model: EmbeddingAD = None) -> dict: |
| seq_score, seq_warn = check_sequence(seq) |
| lig_score, lig_warn = check_ligand(smiles) |
|
|
| ad_score, ad_dist, in_domain = 100.0, None, True |
| if embedding is not None and ad_model is not None: |
| ad_dist, ad_score, in_domain = ad_model.score(embedding) |
| if not in_domain: |
| seq_warn.append( |
| f"OUT_OF_DOMAIN: protein distance={ad_dist:.3f} " |
| f"(threshold={ad_model.threshold:.3f})" |
| ) |
|
|
| all_warn = seq_warn + lig_warn |
| w_ad = 0.20 if ad_model else 0.0 |
| w_seq = 0.55 if not ad_model else 0.45 |
| w_lig = 1.0 - w_seq - w_ad |
| overall = w_seq * seq_score + w_lig * lig_score + w_ad * ad_score |
|
|
| if overall >= 70 and seq_score >= 60 and lig_score >= 60: |
| flag = 'RELIABLE' |
| elif overall >= 40: |
| flag = 'UNCERTAIN' |
| else: |
| flag = 'UNRELIABLE' |
|
|
| return { |
| 'flag': flag, |
| 'overall': round(overall, 1), |
| 'seq_score': round(seq_score, 1), |
| 'lig_score': round(lig_score, 1), |
| 'ad_score': round(ad_score, 1), |
| 'in_domain': in_domain, |
| 'warnings': all_warn, |
| } |
|
|