ddi / src /training /molecular_features.py
github-actions[bot]
Deploy from GitHub Actions (fb28c05c54cf19184fc3f14f1bf3297ba5749ea2)
d29b763
"""RDKit-based molecular feature extraction utilities.

Provides both per-drug features (``smiles_to_features``) and pair-level
features (``MolecularFeatureExtractor``) for drug-drug interaction (DDI) modeling.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
import numpy as np
from sklearn.preprocessing import StandardScaler
try:
from rdkit import DataStructs
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
except Exception: # pragma: no cover
DataStructs = None # type: ignore
Chem = None # type: ignore
AllChem = None # type: ignore
Descriptors = None # type: ignore
# Module-level logger. It uses an explicit dotted name rather than __name__,
# so its records propagate through the "medcare_ddi" logger hierarchy.
logger = logging.getLogger(name="medcare_ddi.molfeat")
@dataclass
class MoleculeFeatureConfig:
    """Settings shared by the molecular feature extractors.

    Attributes:
        n_bits: Length of each Morgan fingerprint bit vector.
        radius: Morgan fingerprint radius (0 is allowed).
        normalize_descriptors: When True, ``MolecularFeatureExtractor.transform``
            standardizes the descriptor block, provided the extractor was fitted.

    Raises:
        ValueError: If ``n_bits`` is not positive or ``radius`` is negative.
    """

    n_bits: int = 1024
    radius: int = 2
    normalize_descriptors: bool = True

    def __post_init__(self) -> None:
        # Fail fast on nonsensical settings instead of deep inside RDKit/NumPy
        # the first time a fingerprint is computed.
        if self.n_bits <= 0:
            raise ValueError(f"n_bits must be positive, got {self.n_bits}")
        if self.radius < 0:
            raise ValueError(f"radius must be non-negative, got {self.radius}")
# Column names of the per-molecule descriptor vector produced by _desc_vector.
# The order here must match the order in which that function computes values.
DESCRIPTOR_NAMES = [
    "MolWt",              # Descriptors.MolWt
    "LogP",               # Descriptors.MolLogP
    "NumHDonors",         # Descriptors.NumHDonors
    "NumHAcceptors",      # Descriptors.NumHAcceptors
    "TPSA",               # Descriptors.TPSA
    "NumAtoms",           # mol.GetNumAtoms()
    "NumRings",           # Descriptors.RingCount
    "NumRotatableBonds",  # Descriptors.NumRotatableBonds
]

# Column names of the pair-similarity vector produced by
# _pair_similarity_features, in that function's output order.
PAIR_SIMILARITY_NAMES = [
    "tanimoto",
    "dice",
    "cosine",
    "both_valid",   # 1.0 when both molecules parsed, else 0.0
    "any_invalid",  # 1.0 when either molecule failed to parse, else 0.0
]
def _safe_mol(smiles: str):
    """Parse a SMILES string into an RDKit molecule, returning None on failure.

    Args:
        smiles: SMILES string. ``None``, non-string values (e.g. a NaN read
            from a DataFrame) and blank strings are treated as unparseable.

    Returns:
        An RDKit ``Mol``, or ``None`` when the input is missing or blank, cannot
        be parsed, or parses to a molecule with no atoms.

    Raises:
        RuntimeError: If RDKit is not installed.
    """
    if Chem is None:
        raise RuntimeError("RDKit not installed. Install rdkit-pypi or use conda.")
    # RDKit parses "" into a zero-atom molecule rather than returning None, so a
    # missing SMILES would otherwise be treated as a valid drug with all-zero
    # features. Reject it up front.
    if not isinstance(smiles, str) or not smiles.strip():
        return None
    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception:
        # Best-effort parsing: any RDKit failure means "invalid molecule", but
        # keep a trace for debugging instead of swallowing it silently.
        logger.debug("RDKit raised while parsing SMILES %r", smiles, exc_info=True)
        return None
    if mol is None or mol.GetNumAtoms() == 0:
        return None
    return mol
def _desc_vector(mol) -> np.ndarray:
    """Return the descriptor vector of ``mol`` in ``DESCRIPTOR_NAMES`` order.

    A ``None`` molecule (unparseable SMILES) maps to an all-zero vector, so
    invalid inputs still yield a row of the expected width.

    Returns:
        float32 array of shape ``(len(DESCRIPTOR_NAMES),)``.
    """
    if mol is not None:
        raw_values = (
            Descriptors.MolWt(mol),
            Descriptors.MolLogP(mol),
            Descriptors.NumHDonors(mol),
            Descriptors.NumHAcceptors(mol),
            Descriptors.TPSA(mol),
            mol.GetNumAtoms(),
            Descriptors.RingCount(mol),
            Descriptors.NumRotatableBonds(mol),
        )
        return np.array([float(value) for value in raw_values], dtype=np.float32)
    return np.zeros((len(DESCRIPTOR_NAMES),), dtype=np.float32)
def _fingerprint(mol, radius: int, n_bits: int) -> np.ndarray:
    """Return the Morgan fingerprint of ``mol`` as a dense float32 0/1 vector.

    Args:
        mol: RDKit molecule, or ``None`` for an unparseable SMILES.
        radius: Morgan fingerprint radius.
        n_bits: Fingerprint length.

    Returns:
        float32 array of shape ``(n_bits,)``; all zeros when ``mol`` is None.
    """
    if mol is not None:
        bit_vector = AllChem.GetMorganFingerprintAsBitVect(mol, radius, n_bits)
        dense_bits = np.zeros((n_bits,), dtype=np.int8)
        # ConvertToNumpyArray writes the bits into the provided array.
        DataStructs.ConvertToNumpyArray(bit_vector, dense_bits)
        return dense_bits.astype(np.float32)
    return np.zeros((n_bits,), dtype=np.float32)
def _tanimoto(mol_a, mol_b, radius: int, n_bits: int) -> float:
    """Return the Tanimoto similarity of two molecules' Morgan fingerprints.

    Returns 0.0 when either molecule is ``None`` (unparseable SMILES).

    NOTE(review): not called by the code visible in this module; presumably
    kept for external callers. Verify before removing.
    """
    if mol_a is not None and mol_b is not None:
        fp_a, fp_b = [
            AllChem.GetMorganFingerprintAsBitVect(molecule, radius, n_bits)
            for molecule in (mol_a, mol_b)
        ]
        return float(DataStructs.TanimotoSimilarity(fp_a, fp_b))
    return 0.0
def _pair_similarity_features(mol_a, mol_b, radius: int, n_bits: int) -> np.ndarray:
    """Return fingerprint-similarity features for a molecule pair.

    The output follows ``PAIR_SIMILARITY_NAMES`` order: Tanimoto, Dice and
    cosine similarity of the two Morgan fingerprints, then the ``both_valid``
    and ``any_invalid`` flags. If either molecule is ``None``, all similarities
    are 0.0 and only ``any_invalid`` is set.

    Returns:
        float32 array of shape ``(5,)``.
    """
    pair_is_valid = mol_a is not None and mol_b is not None
    if not pair_is_valid:
        return np.array([0.0, 0.0, 0.0, 0.0, 1.0], dtype=np.float32)

    fp_a = AllChem.GetMorganFingerprintAsBitVect(mol_a, radius, n_bits)
    fp_b = AllChem.GetMorganFingerprintAsBitVect(mol_b, radius, n_bits)
    similarities = [
        float(DataStructs.TanimotoSimilarity(fp_a, fp_b)),
        float(DataStructs.DiceSimilarity(fp_a, fp_b)),
        float(DataStructs.CosineSimilarity(fp_a, fp_b)),
    ]
    validity_flags = [1.0, 0.0]  # both_valid, any_invalid
    return np.array(similarities + validity_flags, dtype=np.float32)
def smiles_to_features(smiles_list: List[str], n_bits: int = 1024, radius: int = 2) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
    """Convert SMILES strings to per-molecule feature rows and descriptor metadata.

    Each row is the Morgan fingerprint (``n_bits`` columns) followed by the
    descriptors in ``DESCRIPTOR_NAMES`` order. An unparseable SMILES produces
    an all-zero row, a ``{"valid": False}`` metadata entry and a logged warning.

    Args:
        smiles_list: SMILES strings to featurize.
        n_bits: Morgan fingerprint length.
        radius: Morgan fingerprint radius.

    Returns:
        X: float32 array of shape ``(N, n_bits + len(DESCRIPTOR_NAMES))``.
            For an empty ``smiles_list`` this is an empty ``(0, width)`` array.
        meta: One dict per input. Valid entries hold ``"valid": True`` plus one
            float per descriptor name; invalid entries hold only ``"valid": False``.
    """
    width = n_bits + len(DESCRIPTOR_NAMES)
    rows: List[np.ndarray] = []
    metas: List[Dict[str, Any]] = []
    for s in smiles_list:
        mol = _safe_mol(s)
        if mol is None:
            logger.warning("Invalid SMILES: %s", s)
            rows.append(np.zeros((width,), dtype=np.float32))
            metas.append({"valid": False})
            continue
        fp = _fingerprint(mol, radius=radius, n_bits=n_bits)
        dvec = _desc_vector(mol)
        rows.append(np.concatenate([fp, dvec]))
        metas.append(
            {
                "valid": True,
                **{k: float(v) for k, v in zip(DESCRIPTOR_NAMES, dvec.tolist())},
            }
        )
    if not rows:
        # np.vstack raises on an empty sequence; return a correctly-shaped empty
        # matrix instead, matching MolecularFeatureExtractor.transform.
        return np.zeros((0, width), dtype=np.float32), metas
    return np.vstack(rows), metas
class MolecularFeatureExtractor:
    """Pair-level molecular feature extractor with descriptor normalization.

    Each SMILES pair ``(a, b)`` maps to one row made of:

    * ``|fp_a - fp_b|``: element-wise difference of the two Morgan
      fingerprints (``config.n_bits`` columns);
    * a descriptor block: ``d_a``, ``d_b`` and ``|d_a - d_b|`` (each
      ``len(DESCRIPTOR_NAMES)`` columns), followed by the
      ``PAIR_SIMILARITY_NAMES`` similarity features.

    ``fit`` learns a ``StandardScaler`` over the descriptor block only. When
    ``config.normalize_descriptors`` is set and the extractor has been fitted,
    ``transform`` standardizes that block. Fingerprint columns are never scaled.
    """

    def __init__(self, config: MoleculeFeatureConfig | None = None):
        self.config = config or MoleculeFeatureConfig()
        self.scaler = StandardScaler()
        # True only once the scaler has actually been fitted on data.
        self._is_fitted = False

    @staticmethod
    def _descriptor_dim() -> int:
        """Width of the descriptor block: d_a, d_b, |d_a - d_b|, similarities."""
        return len(DESCRIPTOR_NAMES) * 3 + len(PAIR_SIMILARITY_NAMES)

    def _pair_descriptors(self, mol_a, mol_b) -> np.ndarray:
        """Return the unscaled descriptor block for one pair of molecules."""
        d_a = _desc_vector(mol_a)
        d_b = _desc_vector(mol_b)
        d_delta = np.abs(d_a - d_b)
        sim = _pair_similarity_features(mol_a, mol_b, self.config.radius, self.config.n_bits)
        return np.concatenate([d_a, d_b, d_delta, sim], axis=0)

    def fit(self, smiles_pairs: List[Tuple[str, str]]) -> None:
        """Fit the descriptor scaler on ``smiles_pairs``.

        An empty ``smiles_pairs`` leaves the extractor unchanged. The extractor
        is therefore never marked fitted while its scaler is unfitted, which
        would make ``transform`` fail.
        """
        rows = [
            self._pair_descriptors(_safe_mol(s_a), _safe_mol(s_b))
            for s_a, s_b in smiles_pairs
        ]
        if not rows:
            return
        self.scaler.fit(np.vstack(rows))
        self._is_fitted = True

    def transform(self, smiles_pairs: List[Tuple[str, str]]) -> np.ndarray:
        """Featurize SMILES pairs.

        Args:
            smiles_pairs: ``(smiles_a, smiles_b)`` tuples.

        Returns:
            float32 array of shape ``(len(smiles_pairs), n_bits + descriptor
            block width)``; an empty ``(0, ...)`` array for empty input.
        """
        radius, n_bits = self.config.radius, self.config.n_bits
        fp_rows: List[np.ndarray] = []
        desc_rows: List[np.ndarray] = []
        for s_a, s_b in smiles_pairs:
            mol_a = _safe_mol(s_a)
            mol_b = _safe_mol(s_b)
            fp_a = _fingerprint(mol_a, radius, n_bits)
            fp_b = _fingerprint(mol_b, radius, n_bits)
            # 0/1 fingerprints: 1 where exactly one of the two molecules sets the bit.
            fp_rows.append(np.abs(fp_a - fp_b))
            desc_rows.append(self._pair_descriptors(mol_a, mol_b))
        if not fp_rows:
            return np.zeros((0, n_bits + self._descriptor_dim()), dtype=np.float32)

        desc = np.vstack(desc_rows)
        if self.config.normalize_descriptors:
            if self._is_fitted:
                # One batched call instead of one sklearn call per pair;
                # StandardScaler transforms rows independently, so values match.
                desc = self.scaler.transform(desc).astype(np.float32)
            else:
                logger.warning(
                    "normalize_descriptors is enabled but the extractor is not "
                    "fitted; returning unscaled descriptors. Call fit() first."
                )
        return np.hstack([np.vstack(fp_rows), desc]).astype(np.float32, copy=False)

    def fit_transform(self, smiles_pairs: List[Tuple[str, str]]) -> np.ndarray:
        """Fit the descriptor scaler on ``smiles_pairs``, then transform them."""
        self.fit(smiles_pairs)
        return self.transform(smiles_pairs)
if __name__ == '__main__':
    # Smoke run: aspirin, benzene, and one deliberately unparseable string
    # (expected to log a warning and yield an all-zero row).
    demo_smiles = ['CC(=O)OC1=CC=CC=C1C(=O)O', 'C1=CC=CC=C1', 'INVALID_SMILES']
    features, metadata = smiles_to_features(demo_smiles)
    print('X shape:', features.shape)
    print(metadata)