import glob
import logging
import os
import pickle

import numpy as np
import torch
from rdkit import Chem

from model import DMPNN, MolGraph, BatchMolGraph
from data import TASKS, get_global_features, SCALER_FILE

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _sanitize_array(x):
    """Return *x* with non-finite entries replaced so it is JSON-safe.

    NaN -> 0.0, +inf -> 1.0, -inf -> 0.0
    """
    return np.nan_to_num(x, nan=0.0, posinf=1.0, neginf=0.0)


class Tox21Predictor:
    """Ensemble of DMPNN checkpoints that scores SMILES strings on ``TASKS``.

    On construction the optional feature scaler (``SCALER_FILE``) and every
    checkpoint matching ``checkpoints/model_seed_*.pt`` are loaded. Loading is
    best-effort: a failure is logged and the offending scaler/checkpoint is
    skipped rather than raising.
    """

    def __init__(self, model_dir="."):
        """Load the scaler and all ensemble checkpoints.

        Args:
            model_dir: Kept for interface compatibility.
                NOTE(review): this argument is currently not used; checkpoints
                are globbed from ``checkpoints/`` relative to the working
                directory and the scaler path comes from ``SCALER_FILE``.
                Confirm the intended layout before wiring it in.
        """
        self.models = []
        self.scaler = None

        # Load scaler
        if os.path.exists(SCALER_FILE):
            try:
                with open(SCALER_FILE, "rb") as f:
                    # SECURITY: pickle executes arbitrary code on load; the
                    # scaler file must come from a trusted source.
                    self.scaler = pickle.load(f)
                logger.info("Scaler loaded.")
            except Exception as e:
                logger.error("Error loading scaler: %s", e)

        # Load ensemble checkpoints. glob returns paths in arbitrary order, so
        # sort them to make the ensemble member order reproducible.
        model_paths = sorted(
            glob.glob(os.path.join("checkpoints", "model_seed_*.pt"))
        )
        for path in model_paths:
            try:
                # Keyword is n_tasks to match the DMPNN __init__ in model.py
                model = DMPNN(
                    hidden_size=300,
                    depth=3,
                    global_feats_size=217,
                    n_tasks=len(TASKS),
                )
                model.load_state_dict(torch.load(path, map_location=DEVICE))
                model.to(DEVICE)
                model.eval()
                self.models.append(model)
                logger.info("Loaded checkpoint: %s", path)
            except Exception as e:
                logger.error("Failed to load %s: %s", path, e)

    def predict(self, smiles_list):
        """Predict a score per task for each SMILES string.

        Args:
            smiles_list: Sequence of SMILES strings.

        Returns:
            Dict mapping each SMILES string to ``{task: float}``. Entries that
            cannot be parsed or featurized keep the default ``0.0`` for every
            task, as does every entry when no model is loaded. Because the
            result is keyed by SMILES, duplicate inputs share a single entry.
        """
        # Default results: mapping SMILES to a dict of tasks set to 0.0
        results = {s: {task: 0.0 for task in TASKS} for s in smiles_list}

        if not self.models:
            logger.warning("No models loaded. Returning default zeros.")
            return results

        valid_indices = []
        mol_graphs = []
        global_features = []

        for i, s in enumerate(smiles_list):
            try:
                mol = Chem.MolFromSmiles(s)
                if mol is None:
                    continue
                # Build BOTH the features and the graph before recording the
                # molecule, so a failure in either leaves the three parallel
                # lists aligned.
                feats = _sanitize_array(get_global_features(mol))
                graph = MolGraph(s)
            except Exception as e:
                logger.debug("SMILES error for %s: %s", s, e)
                continue
            valid_indices.append(i)
            mol_graphs.append(graph)
            global_features.append(feats)

        if not mol_graphs:
            return results

        # Scale features
        global_features = np.stack(global_features)
        if self.scaler is not None:
            try:
                global_features = self.scaler.transform(global_features)
            except Exception as e:
                # Best-effort: fall back to the unscaled features.
                logger.error("Scaling error: %s", e)

        global_features = _sanitize_array(global_features)
        global_features_tensor = torch.tensor(
            global_features, dtype=torch.float32, device=DEVICE
        )

        # Batch graph processing
        batch_graph = BatchMolGraph(mol_graphs)

        # Inference: one sigmoid-activated prediction array per ensemble member
        ensemble_preds = []
        with torch.no_grad():
            for model in self.models:
                logits = model(batch_graph, global_features_tensor)
                probs = torch.sigmoid(logits).cpu().numpy()
                ensemble_preds.append(_sanitize_array(probs))

        if not ensemble_preds:
            return results

        # Average ensemble predictions
        avg_preds = np.mean(np.stack(ensemble_preds, axis=0), axis=0)
        avg_preds = _sanitize_array(avg_preds)

        # Map batch rows back to the SMILES they were computed from
        for batch_idx, original_idx in enumerate(valid_indices):
            s = smiles_list[original_idx]
            for t_idx, task in enumerate(TASKS):
                results[s][task] = float(avg_preds[batch_idx, t_idx])

        return results


# ============================
# HF LEADERBOARD ENTRY POINT
# ============================

# Lazily created, process-wide predictor so checkpoints are loaded only once.
_predictor = None


def predict(smiles_list):
    """Leaderboard entry point.

    Creates the shared ``Tox21Predictor`` on first use, then returns a dict
    with a single ``"predictions"`` key holding the raw per-SMILES results
    from ``Tox21Predictor.predict``.
    """
    global _predictor
    if _predictor is None:
        _predictor = Tox21Predictor()

    raw_results = _predictor.predict(smiles_list)
    return {"predictions": raw_results}