tox21-classifier / predict.py
sk16er's picture
Update predict.py
3700a25 verified
import torch
import os
import numpy as np
import logging
import glob
import pickle
from rdkit import Chem
from model import DMPNN, MolGraph, BatchMolGraph
from data import TASKS, get_global_features, SCALER_FILE
# Module-wide logging: INFO-level root configuration plus this module's logger.
# NOTE(review): basicConfig configures the *root* logger as an import side
# effect; it is a no-op if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Inference device: a CUDA GPU when torch reports one as usable, else the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
def _sanitize_array(x):
"""
Ensure JSON-safe numeric output.
NaN -> 0.0, +inf -> 1.0, -inf -> 0.0
"""
return np.nan_to_num(x, nan=0.0, posinf=1.0, neginf=0.0)
class Tox21Predictor:
    """Ensemble DMPNN predictor for the Tox21 tasks.

    On construction, loads an optional pickled feature scaler from
    ``SCALER_FILE`` and every ``model_seed_*.pt`` checkpoint found under
    ``<model_dir>/checkpoints``. ``predict`` averages the sigmoid outputs of
    all loaded models. Loading, featurisation and inference failures are
    logged, and the affected item is skipped rather than raised.

    Args:
        model_dir: Directory containing the ``checkpoints`` folder.
            Defaults to the current working directory.
    """

    def __init__(self, model_dir="."):
        self.models = []
        self.scaler = None
        self._load_scaler()
        self._load_models(model_dir)

    def _load_scaler(self):
        """Load the pickled feature scaler from SCALER_FILE, if that file exists.

        SCALER_FILE is used exactly as given (it is not joined with model_dir).
        """
        if not os.path.exists(SCALER_FILE):
            return
        # NOTE(review): pickle.load can execute arbitrary code; SCALER_FILE
        # must be a trusted, locally produced artifact.
        try:
            with open(SCALER_FILE, "rb") as f:
                self.scaler = pickle.load(f)
            logger.info("Scaler loaded.")
        except Exception as e:
            logger.error("Error loading scaler: %s", e)

    def _load_models(self, model_dir):
        """Build one DMPNN per checkpoint under ``model_dir`` and load its weights.

        Checkpoints are visited in sorted order so that the ensemble and its
        log output are reproducible (glob's order is filesystem-dependent).
        A checkpoint that fails to load is logged and skipped.
        """
        pattern = os.path.join(model_dir, "checkpoints", "model_seed_*.pt")
        for path in sorted(glob.glob(pattern)):
            try:
                # These sizes must match the checkpoint's tensors, otherwise
                # load_state_dict raises. global_feats_size presumably equals
                # the length of get_global_features(); TODO confirm.
                model = DMPNN(
                    hidden_size=300,
                    depth=3,
                    global_feats_size=217,
                    n_tasks=len(TASKS),
                )
                model.load_state_dict(torch.load(path, map_location=DEVICE))
                model.to(DEVICE)
                model.eval()
            except Exception as e:
                logger.error("Failed to load %s: %s", path, e)
                continue
            self.models.append(model)
            logger.info("Loaded checkpoint: %s", path)

    def _featurize(self, smiles_list):
        """Parse and featurise each SMILES, skipping any that fail.

        Returns:
            ``(valid_indices, mol_graphs, global_features)``. The three lists
            are aligned: entry k of each refers to
            ``smiles_list[valid_indices[k]]``.
        """
        valid_indices = []
        mol_graphs = []
        global_features = []
        for i, s in enumerate(smiles_list):
            try:
                mol = Chem.MolFromSmiles(s)
                if mol is None:
                    continue
                feats = _sanitize_array(get_global_features(mol))
                graph = MolGraph(s)
            except Exception as e:
                logger.debug("SMILES error for %s: %s", s, e)
                continue
            # Append only once every step has succeeded. Appending the index
            # before MolGraph() could raise would desynchronise the lists and
            # map predictions onto the wrong SMILES.
            valid_indices.append(i)
            mol_graphs.append(graph)
            global_features.append(feats)
        return valid_indices, mol_graphs, global_features

    def _scale_features(self, global_features):
        """Stack per-molecule features, apply the scaler if loaded, and sanitise."""
        stacked = np.stack(global_features)
        if self.scaler is not None:
            try:
                stacked = self.scaler.transform(stacked)
            except Exception as e:
                # Falls back to unscaled features; presumably the models were
                # trained on scaled inputs, so predictions will be degraded.
                logger.error("Scaling error: %s", e)
        return _sanitize_array(stacked)

    def predict(self, smiles_list):
        """Predict per-task probabilities for each SMILES string.

        Args:
            smiles_list: Iterable of SMILES strings. It is materialised as a
                list, so generators and other single-pass iterables work.

        Returns:
            dict mapping each input SMILES to ``{task: probability}`` over
            ``TASKS``. Entries stay at 0.0 in these cases: the SMILES cannot
            be parsed or featurised, no model is loaded, or every model fails
            at inference.
        """
        smiles_list = list(smiles_list)
        results = {s: {task: 0.0 for task in TASKS} for s in smiles_list}
        if not self.models:
            logger.warning("No models loaded. Returning default zeros.")
            return results

        valid_indices, mol_graphs, global_features = self._featurize(smiles_list)
        if not mol_graphs:
            return results

        features_tensor = torch.tensor(
            self._scale_features(global_features),
            dtype=torch.float32,
            device=DEVICE,
        )
        batch_graph = BatchMolGraph(mol_graphs)

        ensemble_preds = []
        with torch.no_grad():
            for model in self.models:
                try:
                    logits = model(batch_graph, features_tensor)
                except Exception as e:
                    # One failing ensemble member should not sink the batch.
                    logger.error("Inference failed for an ensemble member: %s", e)
                    continue
                probs = torch.sigmoid(logits).cpu().numpy()
                ensemble_preds.append(_sanitize_array(probs))
        if not ensemble_preds:
            return results

        # Average across ensemble members. Row k belongs to smiles_list[valid_indices[k]].
        avg_preds = _sanitize_array(np.mean(np.stack(ensemble_preds, axis=0), axis=0))
        for row, original_idx in enumerate(valid_indices):
            task_scores = results[smiles_list[original_idx]]
            for t_idx, task in enumerate(TASKS):
                task_scores[task] = float(avg_preds[row, t_idx])
        return results
# ============================
# HF LEADERBOARD ENTRY POINT
# ============================
_predictor = None


def predict(smiles_list):
    """Leaderboard entry point.

    Builds one module-wide Tox21Predictor on the first call and reuses it on
    later calls. Returns ``{"predictions": <per-SMILES task dict>}``, where the
    inner value is exactly what ``Tox21Predictor.predict`` returned.
    """
    global _predictor
    predictor = _predictor
    if predictor is None:
        predictor = _predictor = Tox21Predictor()
    return {"predictions": predictor.predict(smiles_list)}