Spaces:
No application file
No application file
| """ | |
| This file contains a TabPFN model for Tox21. | |
| It takes a list of SMILES as input and outputs a nested dictionary keyed by | |
| SMILES and target name. | |
| """ | |
| # --------------------------------------------------------------------------------------- | |
| # Dependencies | |
| import os | |
| import numpy as np | |
| from tabpfn import TabPFNClassifier | |
| from tabpfn.model_loading import load_fitted_tabpfn_model, save_fitted_tabpfn_model | |
| from .utils import TASKS | |
| # --------------------------------------------------------------------------------------- | |
class Tox21TabPFN:
    """A TabPFN classifier that assigns a toxicity score to a given SMILES string.

    Holds one ``TabPFNClassifier`` per Tox21 task in ``self.model`` (a dict keyed
    by task name). ``fit`` and ``predict`` operate on precomputed molecule feature
    arrays, not on raw SMILES strings.
    """

    def __init__(self, seed: int = 42, device: str = "cpu"):
        """Initialize a TabPFN classifier for each Tox21 task in ``TASKS``.

        Args:
            seed (int, optional): seed for TabPFN to ensure reproducibility.
                Defaults to 42.
            device (str, optional): device passed to every ``TabPFNClassifier``
                and used again when loading checkpoints in ``load_model``.
                Defaults to "cpu".
        """
        self.tasks = TASKS
        # One independent (unfitted) classifier per task, all sharing the same seed.
        self.model = {
            task: TabPFNClassifier(random_state=seed, device=device)
            for task in self.tasks
        }
        self.device = device

    def load_model(self, path: str) -> None:
        """Load fitted per-task models from a checkpoint folder.

        For every task, ``<path>/<task>/ckpt.tabpfn_fit`` is loaded onto
        ``self.device`` and replaces that task's entry in ``self.model``.
        This is the same layout that ``save_model`` writes.

        Args:
            path (str): folder path to model checkpoints
        """
        for task in self.tasks:
            self.model[task] = load_fitted_tabpfn_model(
                os.path.join(path, task, "ckpt.tabpfn_fit"), device=self.device
            )

    def save_model(self, path: str) -> None:
        """Save each task's fitted model under a given folder.

        Writes ``<path>/<task>/ckpt.tabpfn_fit`` for every task, which is the
        layout ``load_model`` reads back. Missing task sub-folders are created.

        Args:
            path (str): folder to save the model checkpoints to
        """
        for task in self.tasks:
            task_dir = os.path.join(path, task)
            # Create <path>/<task> up front so saving into a fresh folder works.
            os.makedirs(task_dir, exist_ok=True)
            save_fitted_tabpfn_model(
                self.model[task], os.path.join(task_dir, "ckpt.tabpfn_fit")
            )

    def fit(self, task: str, input_features: np.ndarray, labels: np.ndarray) -> None:
        """Train the TabPFN classifier of a single task.

        Args:
            task (str): task to train; must be one of ``self.tasks``
            input_features (np.ndarray): training features
            labels (np.ndarray): 1-dim training labels

        Raises:
            AssertionError: if ``task`` is unknown or ``labels`` is not 1-dim.
        """
        # NOTE(review): validation uses assert, so it is skipped under ``python -O``.
        assert task in self.tasks, f"Unknown task: {task}"
        assert (
            labels.ndim == 1
        ), "2-dim labels passed. Function only accepts 1-dim labels."
        self.model[task].fit(input_features, labels)

    def predict(self, task: str, features: np.ndarray) -> np.ndarray:
        """Predict the positive-class probability for one Tox21 target.

        Args:
            task (str): the Tox21 target to predict for; must be one of ``self.tasks``
            features (np.ndarray): 2-dim molecule features used for prediction

        Returns:
            np.ndarray: predicted probability for the positive class, taken from
                column 1 of ``predict_proba``

        Raises:
            AssertionError: if ``task`` is unknown or ``features`` is not 2-dim.
        """
        assert task in self.tasks, f"Unknown task: {task}"
        assert (
            features.ndim == 2
        ), f"Function expects 2D np.array. Current shape: {features.shape}"
        preds = self.model[task].predict_proba(features)
        # Column 1 is treated as the positive class; assumes binary 0/1 labels
        # so the second predict_proba column belongs to class 1.
        return preds[:, 1]