# MultiTaskTox/src/features.py
# Last commit 94b1553 by mschuh: "Added first version" (verified).
from __future__ import annotations
from pathlib import Path
from typing import Dict, Sequence
import numpy as np
import pandas as pd
from rdkit import DataStructs
from rdkit.Chem import AllChem
from .constants import CANONICAL_SMILES_COLUMN
from .preprocess import MoleculeBatch, filter_dataframe_by_mask, standardize_smiles
try:
from map4 import MAP4Calculator # type: ignore
except Exception: # pragma: no cover - optional dependency
MAP4Calculator = None
class FingerprintFeaturizer:
    """Compute molecular fingerprints (ECFP or MAP4) with optional ``.npz`` caching.

    Recognised ``feature_config`` keys (defaults in parentheses):

    - ``type`` ("ecfp"): fingerprint family, ``"ecfp"`` or ``"map4"``; lower-cased.
    - ``radius`` (2): Morgan radius used for ECFP.
    - ``n_bits`` (1024): ECFP vector length.
    - ``map4_dim`` (1024): MAP4 vector length.
    - ``use_counts`` (False): ECFP only; fold Morgan counts instead of on/off bits.
    - ``cache_dir`` (unset): directory for per-split caches; caching is off if unset.

    Feature matrices are ``float32`` with shape ``(n_molecules, dim)``, where ``dim``
    is ``n_bits`` for ECFP and ``map4_dim`` for MAP4.
    """

    def __init__(self, feature_config: Dict):
        self.config = feature_config
        self.fingerprint_type = feature_config.get("type", "ecfp").lower()
        self.radius = feature_config.get("radius", 2)
        self.n_bits = feature_config.get("n_bits", 1024)
        self.map4_dim = feature_config.get("map4_dim", 1024)
        self.use_counts = feature_config.get("use_counts", False)
        cache_dir = feature_config.get("cache_dir")
        self.cache_dir = Path(cache_dir) if cache_dir else None
        if self.cache_dir:
            self.cache_dir.mkdir(parents=True, exist_ok=True)

    def featurize_dataframe(self, df: pd.DataFrame, split_name: str):
        """Standardize ``df["smiles"]``, filter ``df`` and fingerprint the molecules.

        Returns ``(clean_df, features)``:

        - ``clean_df`` is ``df`` passed through ``filter_dataframe_by_mask`` with the
          standardization mask and canonical SMILES.
        - ``features`` is the fingerprint matrix for ``batch.mols``.

        When caching is enabled, a cache entry for ``split_name`` is reused only if
        it was written with the same fingerprint settings (they are part of the file
        name) and passes ``_cache_matches``. Otherwise the features are recomputed
        and the cache is rewritten.

        NOTE(review): the cache is not keyed on the SMILES themselves. A changed
        dataset with the same row count under the same ``split_name`` would still
        hit a stale cache, so clear ``cache_dir`` when the data changes.
        """
        cache_payload = self._load_cache(split_name)
        if cache_payload is not None and self._cache_matches(cache_payload, len(df)):
            canonical_smiles = cache_payload["canonical_smiles"].tolist()
            clean_df = filter_dataframe_by_mask(df, cache_payload["mask"], canonical_smiles)
            return clean_df, cache_payload["features"]

        batch = standardize_smiles(df["smiles"].tolist())
        clean_df = filter_dataframe_by_mask(df, batch.mask, batch.canonical_smiles)
        features = self._compute_fingerprints(batch.mols)
        self._write_cache(split_name, batch.mask, batch.canonical_smiles, features)
        return clean_df, features

    def featurize_smiles(self, smiles: Sequence[str]) -> tuple[MoleculeBatch, np.ndarray]:
        """Standardize ``smiles`` and fingerprint the result, bypassing the cache.

        Returns the ``MoleculeBatch`` from ``standardize_smiles`` together with the
        fingerprint matrix computed from ``batch.mols``.
        """
        batch = standardize_smiles(smiles)
        features = self._compute_fingerprints(batch.mols)
        return batch, features

    def _cache_tag(self) -> str:
        """Encode every setting that affects the feature values into a file-name tag."""
        if self.fingerprint_type == "map4":
            return f"map4_d{self.map4_dim}"
        if self.fingerprint_type == "ecfp":
            mode = "counts" if self.use_counts else "bits"
            return f"ecfp_r{self.radius}_b{self.n_bits}_{mode}"
        return self.fingerprint_type

    def _cache_path(self, split_name: str) -> Path | None:
        """Return the ``.npz`` cache path for ``split_name``, or None if caching is off."""
        if self.cache_dir is None:
            return None
        # Including the fingerprint parameters in the file name means that changing
        # radius/n_bits/use_counts/map4_dim never reuses features from other settings.
        return self.cache_dir / f"{split_name}_{self._cache_tag()}.npz"

    def _load_cache(self, split_name: str) -> Dict[str, np.ndarray] | None:
        """Load a cache entry as a plain ``dict`` of arrays, or None if absent/disabled."""
        cache_path = self._cache_path(split_name)
        if cache_path is None or not cache_path.exists():
            return None
        # allow_pickle is required because canonical SMILES are stored as an object
        # array (see _write_cache). Pickle can execute code, so only point cache_dir
        # at a trusted location.
        # The context manager closes the underlying zip file. Each array is
        # materialized before the file is closed.
        with np.load(cache_path, allow_pickle=True) as payload:
            return {key: payload[key] for key in payload.files}

    def _cache_matches(self, payload: Dict[str, np.ndarray], n_rows: int) -> bool:
        """Return True if ``payload`` is complete and consistent with the input and config.

        Checks that all expected arrays are present and that the mask length equals
        ``n_rows``. It also checks that ``features`` is 2-D with the configured width.
        Assumes the mask has one entry per input row, as produced by
        ``standardize_smiles``. TODO confirm against ``preprocess``.
        """
        if any(key not in payload for key in ("mask", "canonical_smiles", "features")):
            return False
        features = payload["features"]
        return (
            len(payload["mask"]) == n_rows
            and features.ndim == 2
            and features.shape[1] == self._fingerprint_dimension()
        )

    def _write_cache(self, split_name: str, mask, canonical_smiles, features):
        """Persist mask, canonical SMILES and features for ``split_name`` (no-op if disabled)."""
        cache_path = self._cache_path(split_name)
        if cache_path is None:
            return
        # Write to a sibling temp file and rename it into place. This way an
        # interrupted run never leaves a truncated .npz that would break the next
        # load. Writing through an open handle stops np.savez from appending
        # another ".npz" suffix.
        tmp_path = cache_path.with_name(cache_path.name + ".tmp")
        with open(tmp_path, "wb") as handle:
            np.savez(
                handle,
                mask=mask,
                canonical_smiles=np.array(canonical_smiles, dtype=object),
                features=features,
            )
        tmp_path.replace(cache_path)

    def _compute_fingerprints(self, mols):
        """Dispatch to the configured fingerprint implementation.

        Raises:
            ValueError: if ``fingerprint_type`` is neither ``"ecfp"`` nor ``"map4"``.
                This is checked before the empty-input shortcut so that the error
                is raised consistently for any input size.
        """
        if self.fingerprint_type not in ("ecfp", "map4"):
            raise ValueError(f"Unsupported fingerprint type: {self.fingerprint_type}")
        if len(mols) == 0:
            return np.zeros((0, self._fingerprint_dimension()), dtype=np.float32)
        if self.fingerprint_type == "ecfp":
            return self._compute_ecfp(mols)
        return self._compute_map4(mols)

    def _fingerprint_dimension(self) -> int:
        """Return the feature width for the configured fingerprint type."""
        if self.fingerprint_type == "map4":
            return self.map4_dim
        return self.n_bits

    def _compute_ecfp(self, mols):
        """Compute Morgan/ECFP fingerprints as a ``(len(mols), n_bits)`` float32 matrix.

        With ``use_counts``, the unfolded Morgan count fingerprint is folded into
        ``n_bits`` slots by ``id % n_bits``. Counts that collide in a slot are summed.
        Otherwise RDKit's folded bit vector of length ``n_bits`` is used.
        """
        n_bits = self.n_bits
        fingerprints = np.zeros((len(mols), n_bits), dtype=np.float32)
        for idx, mol in enumerate(mols):
            if self.use_counts:
                sparse_fp = AllChem.GetMorganFingerprint(mol, self.radius)
                for bit, count in sparse_fp.GetNonzeroElements().items():
                    fingerprints[idx, bit % n_bits] += count
            else:
                bitvect = AllChem.GetMorganFingerprintAsBitVect(
                    mol,
                    self.radius,
                    nBits=n_bits,
                )
                # ConvertToNumpyArray may resize its destination, so give it a
                # standalone array rather than a row view of `fingerprints`.
                arr = np.zeros(n_bits, dtype=np.float32)
                DataStructs.ConvertToNumpyArray(bitvect, arr)
                fingerprints[idx] = arr
        return fingerprints

    def _compute_map4(self, mols):
        """Compute MAP4 fingerprints as a ``(len(mols), map4_dim)`` float32 matrix.

        Raises:
            ImportError: if the optional ``map4`` package could not be imported.
        """
        if MAP4Calculator is None:
            raise ImportError(
                "MAP4 fingerprint requested but the `map4` package is not installed. "
                "Install it via `pip install map4` or switch features.type to 'ecfp'."
            )
        calc = MAP4Calculator(dimensions=self.map4_dim)
        fingerprints = np.zeros((len(mols), self.map4_dim), dtype=np.float32)
        for idx, mol in enumerate(mols):
            # Assumes calc.calculate returns a numeric sequence of length map4_dim.
            # NOTE(review): if it returns raw MinHash values, the float32 cast loses
            # precision above 2**24. TODO confirm for the installed map4 version.
            fingerprints[idx] = np.array(calc.calculate(mol), dtype=np.float32)
        return fingerprints