Spaces:
Sleeping
Sleeping
| import torch | |
| import os | |
| import numpy as np | |
| import logging | |
| import glob | |
| import pickle | |
| from rdkit import Chem | |
| from model import DMPNN, MolGraph, BatchMolGraph | |
| from data import TASKS, get_global_features, SCALER_FILE | |
# Configure logging: module-level logger per the logging best practice of
# keying on __name__ so records are attributed to this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
| def _sanitize_array(x): | |
| """ | |
| Ensure JSON-safe numeric output. | |
| NaN -> 0.0, +inf -> 1.0, -inf -> 0.0 | |
| """ | |
| return np.nan_to_num(x, nan=0.0, posinf=1.0, neginf=0.0) | |
class Tox21Predictor:
    """Ensemble D-MPNN predictor for the Tox21 tasks.

    On construction this loads the fitted global-feature scaler (if present)
    and every ``model_seed_*.pt`` checkpoint found under
    ``<model_dir>/checkpoints``. At prediction time the sigmoid probabilities
    of all loaded models are averaged.
    """

    def __init__(self, model_dir="."):
        """Load the scaler and the checkpoint ensemble.

        Args:
            model_dir: Root directory containing the ``checkpoints/`` folder.
                Defaults to the current directory.
        """
        self.models = []
        self.scaler = None
        # Load scaler (best effort: a missing or corrupt scaler just means
        # global features are used unscaled).
        if os.path.exists(SCALER_FILE):
            try:
                with open(SCALER_FILE, "rb") as f:
                    self.scaler = pickle.load(f)
                logger.info("Scaler loaded.")
            except Exception as e:
                logger.error(f"Error loading scaler: {e}")
        # Load ensemble checkpoints.
        # Fix: honor model_dir — it was previously accepted but ignored
        # (the glob was hard-coded to "checkpoints"). With the default ".",
        # this resolves to the same files as before.
        model_paths = glob.glob(
            os.path.join(model_dir, "checkpoints", "model_seed_*.pt")
        )
        for path in model_paths:
            try:
                # n_tasks matches the DMPNN __init__ in model.py.
                model = DMPNN(
                    hidden_size=300,
                    depth=3,
                    global_feats_size=217,
                    n_tasks=len(TASKS),
                )
                model.load_state_dict(torch.load(path, map_location=DEVICE))
                model.to(DEVICE)
                model.eval()
                self.models.append(model)
                logger.info(f"Loaded checkpoint: {path}")
            except Exception as e:
                logger.error(f"Failed to load {path}: {e}")

    def predict(self, smiles_list):
        """Predict task probabilities for a list of SMILES strings.

        Args:
            smiles_list: Iterable of SMILES strings.

        Returns:
            dict mapping each input SMILES to ``{task: probability}``.
            Invalid SMILES (or a total failure to load models) yield 0.0
            for every task, so the output is always fully populated.
        """
        # Default results: every SMILES maps to all-zero task scores.
        results = {s: {task: 0.0 for task in TASKS} for s in smiles_list}
        if not self.models:
            logger.warning("No models loaded. Returning default zeros.")
            return results
        # Featurize only the parseable SMILES; remember their original
        # positions so predictions can be mapped back.
        valid_indices = []
        mol_graphs = []
        global_features = []
        for i, s in enumerate(smiles_list):
            try:
                mol = Chem.MolFromSmiles(s)
                if mol is None:
                    continue
                feats = _sanitize_array(get_global_features(mol))
                valid_indices.append(i)
                mol_graphs.append(MolGraph(s))
                global_features.append(feats)
            except Exception as e:
                logger.debug(f"SMILES error for {s}: {e}")
                continue
        if not mol_graphs:
            return results
        # Scale features (best effort: on scaler failure, fall back to the
        # unscaled values rather than crashing).
        global_features = np.stack(global_features)
        if self.scaler is not None:
            try:
                global_features = self.scaler.transform(global_features)
            except Exception as e:
                logger.error(f"Scaling error: {e}")
        global_features = _sanitize_array(global_features)
        global_features_tensor = torch.tensor(
            global_features, dtype=torch.float32, device=DEVICE
        )
        # Batch all molecular graphs for a single forward pass per model.
        batch_graph = BatchMolGraph(mol_graphs)
        ensemble_preds = []
        with torch.no_grad():
            for model in self.models:
                logits = model(batch_graph, global_features_tensor)
                probs = torch.sigmoid(logits).cpu().numpy()
                ensemble_preds.append(_sanitize_array(probs))
        if not ensemble_preds:
            return results
        # Average ensemble predictions across models.
        avg_preds = np.mean(np.stack(ensemble_preds, axis=0), axis=0)
        avg_preds = _sanitize_array(avg_preds)
        # Map batch rows back to their original SMILES keys.
        for batch_idx, original_idx in enumerate(valid_indices):
            s = smiles_list[original_idx]
            for t_idx, task in enumerate(TASKS):
                results[s][task] = float(avg_preds[batch_idx, t_idx])
        return results
# ============================
# HF LEADERBOARD ENTRY POINT
# ============================
_predictor = None

def predict(smiles_list):
    """Leaderboard entry point.

    Lazily builds a module-level Tox21Predictor on first call, then wraps
    its raw results dictionary under a "predictions" key.
    """
    global _predictor
    if _predictor is None:
        _predictor = Tox21Predictor()
    return {"predictions": _predictor.predict(smiles_list)}