File size: 4,931 Bytes
94b1553
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from __future__ import annotations

import warnings
import zipfile
from pathlib import Path
from typing import Dict, Sequence

import numpy as np
import pandas as pd
from rdkit import DataStructs
from rdkit.Chem import AllChem

from .constants import CANONICAL_SMILES_COLUMN
from .preprocess import MoleculeBatch, filter_dataframe_by_mask, standardize_smiles

try:
    from map4 import MAP4Calculator  # type: ignore
except Exception:  # pragma: no cover - optional dependency
    MAP4Calculator = None


class FingerprintFeaturizer:
    """Compute molecular fingerprints (ECFP/Morgan or MAP4) with optional on-disk caching.

    Recognised ``feature_config`` keys (defaults in parentheses):

    - ``type`` ("ecfp"): ``"ecfp"`` or ``"map4"``; lower-cased before use.
    - ``radius`` (2): Morgan radius, used for ECFP only.
    - ``n_bits`` (1024): ECFP vector length.
    - ``map4_dim`` (1024): MAP4 vector length.
    - ``use_counts`` (False): ECFP only; fold Morgan environment counts into
      the vector instead of setting presence bits.
    - ``cache_dir`` (unset): directory for ``.npz`` caches, created if missing.
      Caching is disabled when this key is absent or falsy.
    """

    _SUPPORTED_TYPES = ("ecfp", "map4")

    def __init__(self, feature_config: Dict):
        self.config = feature_config
        self.fingerprint_type = feature_config.get("type", "ecfp").lower()
        self.radius = feature_config.get("radius", 2)
        self.n_bits = feature_config.get("n_bits", 1024)
        self.map4_dim = feature_config.get("map4_dim", 1024)
        self.use_counts = feature_config.get("use_counts", False)
        cache_dir = feature_config.get("cache_dir")
        self.cache_dir = Path(cache_dir) if cache_dir else None
        if self.cache_dir:
            self.cache_dir.mkdir(parents=True, exist_ok=True)

    def featurize_dataframe(self, df: pd.DataFrame, split_name: str):
        """Featurize ``df["smiles"]``, reusing the cache for ``split_name`` when valid.

        Returns:
            ``(clean_df, features)``: ``clean_df`` is produced by
            ``filter_dataframe_by_mask``; ``features`` is a float32 array with
            one row per molecule kept by ``standardize_smiles``.
        """
        cache_payload = self._load_cache(split_name)
        # A cache written for a different dataframe under the same split name
        # must not be reused. Assumes the mask has one entry per input row, as
        # produced by standardize_smiles(df["smiles"]) -- TODO confirm.
        if cache_payload is not None and len(cache_payload["mask"]) == len(df):
            mask = cache_payload["mask"]
            canonical_smiles = cache_payload["canonical_smiles"].tolist()
            features = cache_payload["features"]
            clean_df = filter_dataframe_by_mask(df, mask, canonical_smiles)
            return clean_df, features

        batch = standardize_smiles(df["smiles"].tolist())
        clean_df = filter_dataframe_by_mask(df, batch.mask, batch.canonical_smiles)
        features = self._compute_fingerprints(batch.mols)

        self._write_cache(split_name, batch.mask, batch.canonical_smiles, features)
        return clean_df, features

    def featurize_smiles(self, smiles: Sequence[str]) -> tuple[MoleculeBatch, np.ndarray]:
        """Standardize ``smiles`` and fingerprint the resulting molecules (no caching)."""
        batch = standardize_smiles(smiles)
        features = self._compute_fingerprints(batch.mols)
        return batch, features

    def _cache_tag(self) -> str:
        """Return a filename fragment encoding every parameter that affects feature values."""
        if self.fingerprint_type == "map4":
            return f"d{self.map4_dim}"
        mode = "counts" if self.use_counts else "bits"
        return f"r{self.radius}_n{self.n_bits}_{mode}"

    def _cache_path(self, split_name: str) -> Path | None:
        """Return the cache file for ``split_name``, or ``None`` when caching is off.

        The name includes the fingerprint parameters so that changing the
        configuration never reloads features computed with different settings.
        """
        if self.cache_dir is None:
            return None
        return self.cache_dir / f"{split_name}_{self.fingerprint_type}_{self._cache_tag()}.npz"

    def _load_cache(self, split_name: str):
        """Return the cached ``mask``/``canonical_smiles``/``features`` arrays as a dict.

        Returns ``None`` on a cache miss or when the file cannot be read (e.g. a
        truncated write), so the caller recomputes and overwrites it.
        """
        cache_path = self._cache_path(split_name)
        if cache_path is None or not cache_path.is_file():
            return None
        # Without this gate, np.load(allow_pickle=True) falls back to
        # pickle.load for non-zip files.
        if not zipfile.is_zipfile(cache_path):
            warnings.warn(f"Ignoring unreadable fingerprint cache: {cache_path}")
            return None
        try:
            # allow_pickle is required because canonical_smiles is stored as an
            # object array. Unpickling can execute code, so cache_dir must only
            # contain files written by this class.
            with np.load(cache_path, allow_pickle=True) as payload:
                # Materialize the arrays so the underlying file handle is closed.
                return {key: payload[key] for key in ("mask", "canonical_smiles", "features")}
        except (OSError, ValueError, EOFError, KeyError, zipfile.BadZipFile) as err:
            warnings.warn(f"Ignoring unreadable fingerprint cache {cache_path}: {err}")
            return None

    def _write_cache(self, split_name: str, mask, canonical_smiles, features):
        """Persist featurization results for ``split_name``; no-op when caching is off.

        Writes to a temporary sibling file and renames it into place, so an
        interrupted write never leaves a partial file at the cache path.
        """
        cache_path = self._cache_path(split_name)
        if cache_path is None:
            return
        tmp_path = cache_path.with_name(cache_path.name + ".tmp")
        try:
            # Writing through a file handle stops np.savez from appending ".npz".
            with tmp_path.open("wb") as handle:
                np.savez(
                    handle,
                    mask=mask,
                    canonical_smiles=np.array(canonical_smiles, dtype=object),
                    features=features,
                )
            tmp_path.replace(cache_path)
        except BaseException:
            tmp_path.unlink(missing_ok=True)
            raise

    def _compute_fingerprints(self, mols):
        """Return a float32 ``(len(mols), dim)`` fingerprint matrix for ``mols``.

        Raises:
            ValueError: if the configured fingerprint type is unsupported. This
                is checked even for empty input.
        """
        if self.fingerprint_type not in self._SUPPORTED_TYPES:
            raise ValueError(f"Unsupported fingerprint type: {self.fingerprint_type}")
        if len(mols) == 0:
            return np.zeros((0, self._fingerprint_dimension()), dtype=np.float32)
        if self.fingerprint_type == "map4":
            return self._compute_map4(mols)
        return self._compute_ecfp(mols)

    def _fingerprint_dimension(self) -> int:
        """Width of one fingerprint row for the configured type."""
        if self.fingerprint_type == "map4":
            return self.map4_dim
        return self.n_bits

    def _compute_ecfp(self, mols):
        """Compute Morgan fingerprints as float32 rows of length ``n_bits``.

        With ``use_counts`` the unfolded count fingerprint is folded by modulo,
        and environments that collide on a slot are summed.
        """
        # NOTE(review): GetMorganFingerprint / GetMorganFingerprintAsBitVect are
        # deprecated in recent RDKit releases in favour of rdFingerprintGenerator;
        # confirm against the pinned RDKit version before migrating.
        n_bits = self.n_bits
        fingerprints = np.zeros((len(mols), n_bits), dtype=np.float32)
        for idx, mol in enumerate(mols):
            if self.use_counts:
                sparse = AllChem.GetMorganFingerprint(mol, self.radius)
                row = fingerprints[idx]  # view: in-place updates write through
                for bit, value in sparse.GetNonzeroElements().items():
                    row[bit % n_bits] += value
            else:
                bitvect = AllChem.GetMorganFingerprintAsBitVect(
                    mol,
                    self.radius,
                    nBits=n_bits,
                )
                arr = np.zeros(n_bits, dtype=np.float32)
                DataStructs.ConvertToNumpyArray(bitvect, arr)
                fingerprints[idx] = arr
        return fingerprints

    def _compute_map4(self, mols):
        """Compute MAP4 fingerprints as float32 rows of length ``map4_dim``.

        Raises:
            ImportError: if the optional ``map4`` package is not installed.
        """
        if MAP4Calculator is None:
            raise ImportError(
                "MAP4 fingerprint requested but the `map4` package is not installed. "
                "Install it via `pip install map4` or switch features.type to 'ecfp'."
            )

        calc = MAP4Calculator(dimensions=self.map4_dim)
        fingerprints = np.zeros((len(mols), self.map4_dim), dtype=np.float32)
        for idx, mol in enumerate(mols):
            fingerprints[idx] = np.array(calc.calculate(mol), dtype=np.float32)
        return fingerprints