antoniaebner's picture
add config usage
511fdc4
"""
This file provides a predict function for the Tox21 tasks.
It takes a list of SMILES strings as input and returns a nested dictionary
keyed first by SMILES string and then by target name.
"""
# ---------------------------------------------------------------------------------------
# Dependencies
import os
import json
from collections import defaultdict
import numpy as np
from src.model import Tox21TabPFN
from src.data import create_descriptors
from src.utils import load_pickle, KNOWN_DESCR, normalize_config
# ---------------------------------------------------------------------------------------
from tqdm import tqdm
# Path of the JSON configuration read by ``predict``; being relative, it is
# resolved against the current working directory when the file is opened.
CONFIG_FILE: str = "config/config.json"
def _load_config() -> dict:
    """Read ``CONFIG_FILE`` as JSON and return it after ``normalize_config``."""
    # JSON text is UTF-8 by specification; don't depend on the platform locale.
    with open(CONFIG_FILE, "r", encoding="utf-8") as f:
        config = json.load(f)
    return normalize_config(config)


def _featurize(smiles_list: list[str], config: dict) -> np.ndarray:
    """Build the scaled feature matrix for ``smiles_list``.

    Loads the fitted preprocessing artifacts from ``config["data_folder"]``
    (scaler, ECDFs, feature selection, feature order), computes descriptors
    with ``create_descriptors``, concatenates them in ``KNOWN_DESCR`` order,
    reorders the columns by ``feature_order`` and applies the scaler.

    Args:
        smiles_list (list[str]): list of SMILES strings
        config (dict): normalized config; reads ``data_folder`` and
            ``ecfp.radius`` / ``ecfp.fpsize``.

    Returns:
        np.ndarray: feature matrix, expected to have one row per input SMILES.
        Rows that are entirely NaN belong to molecules that could not be
        cleaned.
    """
    data_folder = config["data_folder"]
    scaler_path = os.path.join(data_folder, "scaler.pkl")
    ecdfs_path = os.path.join(data_folder, "ecdfs.pkl")
    feature_selection_path = os.path.join(data_folder, "feat_selection.npz")
    feature_order_path = os.path.join(data_folder, "feature_order.npy")

    # NOTE(review): assuming load_pickle unpickles, these paths must only point
    # at trusted artifacts — unpickling can execute arbitrary code.
    scaler = load_pickle(scaler_path)
    ecdfs = load_pickle(ecdfs_path)
    # NOTE(review): np.load on an .npz returns a lazily-read NpzFile that keeps
    # the file open; confirm whether create_descriptors relies on that before
    # switching to an eager `with np.load(...)` load.
    feature_selection = np.load(feature_selection_path)
    feature_order = np.load(feature_order_path)
    print(f"Loaded scaler from: {scaler_path}")
    print(f"Loaded ecdfs from: {ecdfs_path}")
    print(f"Loaded feature selection from: {feature_selection_path}")
    print(f"Loaded feature order from: {feature_order_path}")

    descriptors = create_descriptors(
        smiles_list,
        ecdfs=ecdfs,
        feature_selection=feature_selection,
        radius=config["ecfp"]["radius"],
        fpsize=config["ecfp"]["fpsize"],
    )["features"]
    features = np.concatenate([descriptors[descr] for descr in KNOWN_DESCR], axis=1)
    # Columns must be in the order the scaler (and model) were fitted on.
    features = scaler.transform(features[:, feature_order])
    print("Created descriptors for molecules.")
    return features


def predict(
    smiles_list: list[str], default_prediction=0.5
) -> dict[str, dict[str, float]]:
    """Apply the Tox21 classifier to a list of SMILES strings.

    Molecules that could not be cleaned (their feature row is entirely NaN)
    are not passed to the model. They receive ``default_prediction`` for every
    target instead.

    Args:
        smiles_list (list[str]): list of SMILES strings
        default_prediction (float): value reported for every target of a
            molecule that could not be cleaned. Defaults to 0.5.

    Returns:
        dict: nested prediction dictionary, following
        {'<smiles>': {'<target>': <pred>}}. Duplicate SMILES map to a single
        entry (the last occurrence wins). An empty input yields an empty
        result without loading any artifacts or the model.
    """
    print(f"Received {len(smiles_list)} SMILES strings")
    # len() rather than truthiness so array-like inputs also work here.
    if len(smiles_list) == 0:
        return defaultdict(dict)

    config = _load_config()

    # preprocessing pipeline
    features = _featurize(smiles_list, config)
    is_clean = ~np.isnan(features).all(axis=1)
    print(f"{(~is_clean).sum()} molecules removed during cleaning")

    # setup model
    model = Tox21TabPFN(seed=123, device=config["device"])
    model.load_model(config["ckpt_path"])
    print(f"Loaded model from: {config['ckpt_path']}")

    # make predictions
    predictions = defaultdict(dict)
    print("Create predictions:")
    # If every molecule failed cleaning, never hand the model an empty batch;
    # all rows simply keep the default prediction.
    any_clean = bool(is_clean.any())
    for target in tqdm(model.tasks):
        preds = np.full(is_clean.shape, default_prediction, dtype=float)
        if any_clean:
            preds[is_clean] = model.predict(target, features[is_clean])
        for smiles, pred in zip(smiles_list, preds):
            predictions[smiles][target] = float(pred)
    return predictions