File size: 2,565 Bytes
444d15c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
This file includes an RF model for Tox21.
As an input it takes a list of SMILES and it outputs a nested dictionary with
SMILES and target names as keys.
"""

# ---------------------------------------------------------------------------------------
# Dependencies
import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier

from .utils import TASKS


# ---------------------------------------------------------------------------------------
class _InvalidInputError(ValueError, AssertionError):
    """Raised when a caller passes an unknown task or a badly shaped feature array.

    Derives from ``AssertionError`` as well as ``ValueError`` so that handlers
    written against the former ``assert``-based checks keep working.
    """


class Tox21RFClassifier:
    """A collection of random forest classifiers, one per task in ``TASKS``.

    Each task gets its own ``RandomForestClassifier``; ``fit`` and ``predict``
    operate on one task at a time using precomputed molecule features.
    """

    def __init__(
        self, seed: int = 42, config: "dict | None" = None, *, n_jobs: int = 8
    ):
        """Initialize one random forest classifier per task.

        Args:
            seed (int, optional): seed for RF to ensure reproducibility. Defaults to 42.
            config (dict, optional): mapping from task name to a dict of keyword
                arguments for that task's ``RandomForestClassifier``. It must
                contain an entry for every task. When None, every forest is
                built with ``n_estimators=1000``. Defaults to None.
            n_jobs (int, optional): number of parallel jobs handed to each
                forest. Defaults to 8.

        Raises:
            KeyError: if ``config`` is given but lacks an entry for a task.
        """
        self.tasks = TASKS

        default_params = {"n_estimators": 1000}
        self.models = {}
        for task in self.tasks:
            task_params = default_params if config is None else config[task]
            # Merge into a single dict so that a per-task config which itself
            # sets ``random_state`` or ``n_jobs`` overrides the shared values
            # instead of failing with a duplicate-keyword TypeError.
            self.models[task] = RandomForestClassifier(
                **{"random_state": seed, "n_jobs": n_jobs, **task_params}
            )

    def _check_task(self, task: str) -> None:
        """Raise if ``task`` is not one of the tasks this instance knows about."""
        if task not in self.tasks:
            raise _InvalidInputError(f"Unknown task: {task}")

    def load(self, path: str) -> None:
        """Load model from filepath

        The loaded object replaces ``self.models`` as-is; ``self.tasks`` is not
        updated.

        Args:
            path (str): filepath to model checkpoint
        """
        # NOTE(review): joblib checkpoints are pickle-based, so loading one can
        # execute arbitrary code. Only load checkpoints from a trusted source.
        self.models = joblib.load(path)

    def save(self, path: str) -> None:
        """Save model to filepath

        Args:
            path (str): filepath to model checkpoint
        """
        joblib.dump(self.models, path)

    def fit(self, task: str, X: np.ndarray, y: np.ndarray) -> None:
        """Train the random forest for a given task

        Args:
            task (str): task to train
            X (np.ndarray): training features
            y (np.ndarray): training labels

        Raises:
            ValueError: if ``task`` is unknown (the exception is also an
                ``AssertionError`` for backward compatibility).
        """
        self._check_task(task)
        # Train on copies so the caller's arrays are never shared with the model.
        features, labels = X.copy(), y.copy()
        self.models[task].fit(features, labels)

    def predict(self, task: str, X: np.ndarray) -> np.ndarray:
        """Predicts labels for a given Tox21 target using molecule features

        Args:
            task (str): the Tox21 target to predict for
            X (np.ndarray): molecule features used for prediction; must be 2D

        Returns:
            np.ndarray: predicted probability for positive class

        Raises:
            ValueError: if ``task`` is unknown or ``X`` is not two-dimensional
                (the exception is also an ``AssertionError`` for backward
                compatibility).
        """
        self._check_task(task)
        if len(X.shape) != 2:
            raise _InvalidInputError(
                f"Function expects 2D np.array. Current shape: {X.shape}"
            )
        features = X.copy()
        # Column 1 of predict_proba is taken as the positive-class probability.
        return self.models[task].predict_proba(features)[:, 1]