antoniaebner's picture
adapt model to use one scaler for all tasks; debug predict
d790da7
"""
This file includes a TabPFN model for Tox21.
As input, it takes a list of SMILES and it outputs a nested dictionary with
SMILES and target names as keys.
"""
# ---------------------------------------------------------------------------------------
# Dependencies
import os
import numpy as np
from tabpfn import TabPFNClassifier
from tabpfn.model_loading import load_fitted_tabpfn_model, save_fitted_tabpfn_model
from .utils import TASKS
# ---------------------------------------------------------------------------------------
class Tox21TabPFN:
    """TabPFN-based Tox21 model that assigns a toxicity score per task.

    One independent ``TabPFNClassifier`` is kept per task in ``self.model``,
    keyed by task name. ``fit`` and ``predict`` operate on precomputed
    molecule feature arrays, not on raw SMILES strings.
    """

    def __init__(self, seed: int = 42, device: str = "cpu"):
        """Initialize a TabPFN classifier for each task in ``TASKS``.

        Args:
            seed (int, optional): seed for TabPFN to ensure reproducibility.
                Defaults to 42.
            device (str, optional): device passed to every TabPFN classifier
                and used when loading checkpoints. Defaults to "cpu".
        """
        self.tasks = TASKS
        self.device = device
        # One classifier per task; all share the same seed and device.
        self.model = {
            task: TabPFNClassifier(random_state=seed, device=device)
            for task in self.tasks
        }

    @staticmethod
    def _checkpoint_path(path: str, task: str) -> str:
        """Return the checkpoint file location for ``task`` below ``path``."""
        return os.path.join(path, task, "ckpt.tabpfn_fit")

    def load_model(self, path: str) -> None:
        """Load fitted per-task classifiers from a folder.

        The folder must contain one subfolder per task, each holding a
        ``ckpt.tabpfn_fit`` file (the layout written by ``save_model``).
        Loaded classifiers replace the ones created in ``__init__``.

        Args:
            path (str): folder containing the per-task checkpoint subfolders.
        """
        for task in self.tasks:
            self.model[task] = load_fitted_tabpfn_model(
                self._checkpoint_path(path, task), device=self.device
            )

    def save_model(self, path: str) -> None:
        """Save every per-task classifier below ``path``.

        Writes ``<path>/<task>/ckpt.tabpfn_fit`` for each task and creates
        the task subfolders first if they do not exist yet.

        Args:
            path (str): folder to save the per-task checkpoints into.
        """
        for task in self.tasks:
            model_path = self._checkpoint_path(path, task)
            # Ensure the per-task folder exists before writing the checkpoint.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            save_fitted_tabpfn_model(self.model[task], model_path)

    def fit(self, task: str, input_features: np.ndarray, labels: np.ndarray) -> None:
        """Train the TabPFN classifier of a single task.

        Args:
            task (str): task to train; must be one of ``self.tasks``.
            input_features (np.ndarray): training features.
            labels (np.ndarray): 1-dim array of training labels.

        Raises:
            AssertionError: if ``task`` is unknown or ``labels`` is not 1-dim.
        """
        assert task in self.tasks, f"Unknown task: {task}"
        assert (
            labels.ndim == 1
        ), "2-dim labels passed. Function only accepts 1-dim labels."
        self.model[task].fit(input_features, labels)

    def predict(self, task: str, features: np.ndarray) -> np.ndarray:
        """Predict positive-class probabilities for one Tox21 task.

        Args:
            task (str): the Tox21 target to predict for; must be in ``self.tasks``.
            features (np.ndarray): 2-dim array of molecule features
                (one row per molecule).

        Returns:
            np.ndarray: predicted probability for the positive class, i.e.
            column 1 of ``predict_proba`` (one value per input row).

        Raises:
            AssertionError: if ``task`` is unknown or ``features`` is not 2-dim.
        """
        assert task in self.tasks, f"Unknown task: {task}"
        assert (
            len(features.shape) == 2
        ), f"Function expects 2D np.array. Current shape: {features.shape}"
        preds = self.model[task].predict_proba(features)
        # Column 1 holds the positive class when labels are {0, 1}.
        return preds[:, 1]