Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from pathlib import Path | |
| from typing import Dict, Sequence | |
| import numpy as np | |
| import pandas as pd | |
| from rdkit import DataStructs | |
| from rdkit.Chem import AllChem | |
| from .constants import CANONICAL_SMILES_COLUMN | |
| from .preprocess import MoleculeBatch, filter_dataframe_by_mask, standardize_smiles | |
| try: | |
| from map4 import MAP4Calculator # type: ignore | |
| except Exception: # pragma: no cover - optional dependency | |
| MAP4Calculator = None | |
class FingerprintFeaturizer:
    """Compute molecular fingerprints with optional on-disk caching.

    Keys read from ``feature_config`` (defaults in parentheses):

    * ``type`` ("ecfp") -- "ecfp" or "map4", compared case-insensitively.
    * ``radius`` (2) -- Morgan radius used for ECFP.
    * ``n_bits`` (1024) -- ECFP vector length.
    * ``map4_dim`` (1024) -- MAP4 vector length.
    * ``use_counts`` (False) -- ECFP counts folded into ``n_bits`` instead of bits.
    * ``cache_dir`` (None) -- directory for ``.npz`` caches; falsy disables caching.
    """

    def __init__(self, feature_config: Dict):
        self.config = feature_config
        self.fingerprint_type = feature_config.get("type", "ecfp").lower()
        self.radius = feature_config.get("radius", 2)
        self.n_bits = feature_config.get("n_bits", 1024)
        self.map4_dim = feature_config.get("map4_dim", 1024)
        self.use_counts = feature_config.get("use_counts", False)
        cache_dir = feature_config.get("cache_dir")
        self.cache_dir = Path(cache_dir) if cache_dir else None
        if self.cache_dir:
            self.cache_dir.mkdir(parents=True, exist_ok=True)

    def featurize_dataframe(
        self, df: pd.DataFrame, split_name: str
    ) -> tuple[pd.DataFrame, np.ndarray]:
        """Standardize ``df["smiles"]`` and fingerprint the resulting molecules.

        Returns ``(clean_df, features)``: ``clean_df`` is ``df`` passed through
        ``filter_dataframe_by_mask`` with the standardization mask and canonical
        SMILES, and ``features`` has one row per molecule in ``batch.mols``.

        When caching is enabled, a cached result is reused only if it was
        written for the same fingerprint settings (encoded in the file name)
        and the same ordered SMILES sequence (checked via a stored digest).
        """
        smiles = df["smiles"].tolist()
        source_digest = self._smiles_digest(smiles)
        cache_payload = self._load_cache(split_name, source_digest)
        if cache_payload is not None:
            clean_df = filter_dataframe_by_mask(
                df,
                cache_payload["mask"],
                cache_payload["canonical_smiles"].tolist(),
            )
            return clean_df, cache_payload["features"]
        batch = standardize_smiles(smiles)
        clean_df = filter_dataframe_by_mask(df, batch.mask, batch.canonical_smiles)
        features = self._compute_fingerprints(batch.mols)
        self._write_cache(
            split_name,
            batch.mask,
            batch.canonical_smiles,
            features,
            source_digest=source_digest,
        )
        return clean_df, features

    def featurize_smiles(self, smiles: Sequence[str]) -> tuple[MoleculeBatch, np.ndarray]:
        """Standardize ``smiles`` and fingerprint them without touching the cache."""
        batch = standardize_smiles(smiles)
        features = self._compute_fingerprints(batch.mols)
        return batch, features

    @staticmethod
    def _smiles_digest(smiles: Sequence) -> str:
        """Return a SHA-256 hex digest identifying an ordered SMILES sequence."""
        import hashlib

        hasher = hashlib.sha256()
        for item in smiles:
            hasher.update(str(item).encode("utf-8", errors="surrogatepass"))
            # NUL separator keeps e.g. ["ab"] and ["a", "b"] from colliding.
            hasher.update(b"\0")
        return hasher.hexdigest()

    def _cache_path(self, split_name: str) -> Path | None:
        """Return the cache file for ``split_name``, or ``None`` if caching is off.

        The file name encodes every setting that changes the fingerprint values
        or width, so e.g. a new ``n_bits`` never reuses a stale cache.
        """
        if self.cache_dir is None:
            return None
        if self.fingerprint_type == "map4":
            settings = f"d{self.map4_dim}"
        else:
            counts = "_counts" if self.use_counts else ""
            settings = f"r{self.radius}_b{self.n_bits}{counts}"
        return self.cache_dir / f"{split_name}_{self.fingerprint_type}_{settings}.npz"

    def _load_cache(
        self, split_name: str, source_digest: str | None = None
    ) -> dict[str, np.ndarray] | None:
        """Load the cached payload for ``split_name`` as a dict of arrays.

        Returns ``None`` when caching is disabled, no cache file exists, or
        ``source_digest`` is given and differs from the stored digest (caches
        written without a digest never match a requested digest).
        """
        cache_path = self._cache_path(split_name)
        if cache_path is None or not cache_path.exists():
            return None
        # NOTE(review): allow_pickle is needed for the object-dtype SMILES array;
        # unpickling can execute code, so only load caches from a trusted writer.
        with np.load(cache_path, allow_pickle=True) as archive:
            # Materialize every member so the archive's file handle is closed.
            payload = {key: archive[key] for key in archive.files}
        if source_digest is not None:
            stored = payload.get("source_digest")
            if stored is None or stored.item() != source_digest:
                return None
        return payload

    def _write_cache(
        self,
        split_name: str,
        mask,
        canonical_smiles,
        features,
        source_digest: str | None = None,
    ) -> None:
        """Persist a featurization result for ``split_name`` (no-op without cache_dir)."""
        cache_path = self._cache_path(split_name)
        if cache_path is None:
            return
        arrays = {
            "mask": mask,
            "canonical_smiles": np.array(canonical_smiles, dtype=object),
            "features": features,
        }
        if source_digest is not None:
            arrays["source_digest"] = np.array(source_digest)
        # Write to a sibling temp file, then swap it in, so an interrupted run
        # cannot leave a truncated archive at ``cache_path``. Passing an open
        # handle stops np.savez from appending another ".npz" suffix.
        tmp_path = cache_path.with_name(cache_path.name + ".tmp")
        with open(tmp_path, "wb") as handle:
            np.savez(handle, **arrays)
        tmp_path.replace(cache_path)

    def _compute_fingerprints(self, mols):
        """Dispatch to the configured fingerprint; returns a float32 matrix.

        Raises ``ValueError`` for an unsupported ``fingerprint_type``. The type
        is checked before the empty-input shortcut, so this happens regardless
        of input size.
        """
        compute = {
            "ecfp": self._compute_ecfp,
            "map4": self._compute_map4,
        }.get(self.fingerprint_type)
        if compute is None:
            raise ValueError(f"Unsupported fingerprint type: {self.fingerprint_type}")
        if not mols:
            return np.zeros((0, self._fingerprint_dimension()), dtype=np.float32)
        return compute(mols)

    def _fingerprint_dimension(self) -> int:
        """Width of one fingerprint row for the configured type."""
        if self.fingerprint_type == "map4":
            return self.map4_dim
        return self.n_bits

    def _compute_ecfp(self, mols):
        """Morgan/ECFP fingerprints as a ``(len(mols), n_bits)`` float32 matrix.

        With ``use_counts`` the unfolded count fingerprint is folded into
        ``n_bits`` slots by summing counts whose identifiers collide modulo
        ``n_bits``; otherwise RDKit's folded bit vector is used.
        """
        fingerprints = np.zeros((len(mols), self.n_bits), dtype=np.float32)
        for idx, mol in enumerate(mols):
            if self.use_counts:
                fp = AllChem.GetMorganFingerprint(mol, self.radius)
                arr = np.zeros(self.n_bits, dtype=np.float32)
                for bit, value in fp.GetNonzeroElements().items():
                    arr[bit % self.n_bits] += value
            else:
                bitvect = AllChem.GetMorganFingerprintAsBitVect(
                    mol,
                    self.radius,
                    nBits=self.n_bits,
                )
                arr = np.zeros(self.n_bits, dtype=np.float32)
                DataStructs.ConvertToNumpyArray(bitvect, arr)
            fingerprints[idx] = arr
        return fingerprints

    def _compute_map4(self, mols):
        """MAP4 fingerprints as a ``(len(mols), map4_dim)`` float32 matrix.

        Raises ``ImportError`` when the optional ``map4`` package is missing.
        """
        if MAP4Calculator is None:
            raise ImportError(
                "MAP4 fingerprint requested but the `map4` package is not installed. "
                "Install it via `pip install map4` or switch features.type to 'ecfp'."
            )
        calc = MAP4Calculator(dimensions=self.map4_dim)
        fingerprints = np.zeros((len(mols), self.map4_dim), dtype=np.float32)
        for idx, mol in enumerate(mols):
            vec = np.array(calc.calculate(mol), dtype=np.float32)
            fingerprints[idx] = vec
        return fingerprints