"""
This file includes a TabPFN model for Tox21. It holds one TabPFN classifier per
Tox21 task; each classifier is trained on, and predicts from, a precomputed
molecule feature matrix.
"""

# ---------------------------------------------------------------------------------------
# Dependencies
import os

import numpy as np
from tabpfn import TabPFNClassifier
from tabpfn.model_loading import load_fitted_tabpfn_model, save_fitted_tabpfn_model

from .utils import TASKS


# ---------------------------------------------------------------------------------------
class Tox21TabPFN:
    """A set of TabPFN classifiers (one per Tox21 task) that assign a toxicity
    score to molecules given their feature vectors."""

    # File name used for each task's fitted checkpoint; shared by save/load so
    # the on-disk layout written by `save_model` is exactly what `load_model` reads.
    _CKPT_NAME = "ckpt.tabpfn_fit"

    def __init__(self, seed: int = 42, device: str = "cpu"):
        """Initialize a TabPFN classifier for each Tox21 task listed in ``TASKS``.

        Args:
            seed (int, optional): seed for TabPFN to ensure reproducibility. Defaults to 42.
            device (str, optional): device passed to every TabPFN classifier and used
                when loading checkpoints. Defaults to "cpu".
        """
        self.tasks = TASKS
        self.model = {
            task: TabPFNClassifier(random_state=seed, device=device)
            for task in self.tasks
        }
        self.device = device

    def load_model(self, path: str) -> None:
        """Load the fitted per-task models from a given folder path.

        The folder must contain one sub-folder per task, each holding the file
        "ckpt.tabpfn_fit" (the layout produced by `save_model`).

        Args:
            path (str): folder path to model checkpoints
        """
        for task in self.tasks:
            self.model[task] = load_fitted_tabpfn_model(
                os.path.join(path, task, self._CKPT_NAME), device=self.device
            )

    def save_model(self, path: str) -> None:
        """Save every per-task model under the given folder path.

        Each task is written to ``<path>/<task>/ckpt.tabpfn_fit``; missing
        folders are created.

        Args:
            path (str): folder path to save the model checkpoints to
        """
        for task in self.tasks:
            task_dir = os.path.join(path, task)
            # Ensure the per-task folder exists so saving into a fresh
            # directory does not fail.
            os.makedirs(task_dir, exist_ok=True)
            save_fitted_tabpfn_model(
                self.model[task], os.path.join(task_dir, self._CKPT_NAME)
            )

    def fit(self, task: str, input_features: np.ndarray, labels: np.ndarray) -> None:
        """Train TabPFN for a given task.

        Args:
            task (str): task to train; must be one of ``self.tasks``
            input_features (np.ndarray): training features
            labels (np.ndarray): 1-dim training labels

        Raises:
            AssertionError: if the task is unknown or labels are not 1-dim.
        """
        # NOTE: asserts are stripped under `python -O`; kept to preserve the
        # exception type existing callers may rely on.
        assert task in self.tasks, f"Unknown task: {task}"
        assert (
            labels.ndim == 1
        ), f"{labels.ndim}-dim labels passed. Function only accepts 1-dim labels."
        self.model[task].fit(input_features, labels)

    def predict(self, task: str, features: np.ndarray) -> np.ndarray:
        """Predict the positive-class probability for a given Tox21 target.

        Args:
            task (str): the Tox21 target to predict for; must be one of ``self.tasks``
            features (np.ndarray): 2-dim array of molecule features used for prediction

        Returns:
            np.ndarray: predicted probability for the positive class (column 1 of
            ``predict_proba``), one value per row of ``features``

        Raises:
            AssertionError: if the task is unknown or features are not 2-dim.
        """
        assert task in self.tasks, f"Unknown task: {task}"
        assert (
            features.ndim == 2
        ), f"Function expects 2D np.array. Current shape: {features.shape}"
        preds = self.model[task].predict_proba(features)
        # Column 1 corresponds to the second class in the classifier's sorted
        # classes; assumes binary {0, 1} labels were used in `fit`.
        return preds[:, 1]