import os
import pickle
import logging
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem.Scaffolds import MurckoScaffold
from sklearn.preprocessing import StandardScaler
from datasets import load_dataset  # Added for leaderboard compliance

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

SCALER_FILE = "scaler.pkl"

# data.py
# The 12 Tox21 assay endpoints (nuclear-receptor and stress-response panels).
TASKS = [
    "NR-AhR", "NR-AR", "NR-AR-LBD", "NR-Aromatase", "NR-ER", "NR-ER-LBD",
    "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53",
]

# Fixed descriptor-vector width expected by the trained model checkpoints.
NUM_GLOBAL_FEATURES = 217


class Tox21Dataset(Dataset):
    """Thin torch Dataset wrapper over a pre-built list of sample dicts."""

    def __init__(self, data):
        # `data` is a list of dicts as produced by load_data()'s format_subset.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def get_global_features(mol, n_features=NUM_GLOBAL_FEATURES):
    """Compute RDKit molecular descriptors as a fixed-width float32 vector.

    Args:
        mol: RDKit Mol object, or None (e.g. for an unparseable SMILES).
        n_features: output width; the vector is zero-padded or truncated so it
            always matches the trained model's expected input size regardless
            of how many descriptors the installed RDKit version provides.

    Returns:
        np.ndarray of shape (n_features,), dtype float32; all zeros if mol is None.
    """
    if mol is None:
        return np.zeros(n_features, dtype=np.float32)

    # Evaluate every descriptor available in the current environment.
    res = []
    for _name, func in Descriptors.descList:
        try:
            res.append(func(mol))
        except Exception:  # some descriptors raise on unusual molecules
            res.append(0.0)

    # PAD OR TRUNCATE TO EXACTLY n_features.
    # This ensures compatibility with the specific model checkpoints.
    if len(res) < n_features:
        res.extend([0.0] * (n_features - len(res)))
    return np.array(res[:n_features], dtype=np.float32)


def scaffold_split(smiles_list, train_frac=0.9):
    """Standard scaffold split for molecular data.

    Molecules are grouped by Bemis-Murcko scaffold; whole groups (largest
    first) are assigned to the training set until `train_frac` of the data is
    reached. Groups that would overflow the cutoff are skipped — not a hard
    stop — so smaller scaffolds can still fill the train set toward the target
    fraction. Everything not assigned to train becomes validation.

    Returns:
        (train_indices, val_indices): lists of integer positions into smiles_list.
    """
    scaffolds = defaultdict(list)
    for i, smiles in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smiles)
        # Unparseable SMILES fall into the "" scaffold bucket rather than crashing.
        scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=True) if mol else ""
        scaffolds[scaffold].append(i)

    # Largest scaffold groups first, so rare scaffolds tend to land in validation.
    scaffold_sets = sorted(scaffolds.values(), key=len, reverse=True)
    train_indices = []
    train_cutoff = train_frac * len(smiles_list)
    for scaffold_set in scaffold_sets:
        if len(train_indices) + len(scaffold_set) > train_cutoff:
            # Too big for the remaining budget; this group goes to validation.
            # (A `break` here would also exclude every later, smaller group and
            # could leave the train set far below train_frac.)
            continue
        train_indices.extend(scaffold_set)

    all_indices = set(range(len(smiles_list)))
    val_indices = list(all_indices - set(train_indices))
    return train_indices, val_indices


def load_data():
    """
    Loads Tox21 from Hugging Face Hub instead of local CSV.
    Matches leaderboard requirements for training verification.

    Returns:
        (train_data, val_data, scaler): two lists of sample dicts (keys
        'smiles', 'labels', 'global_features', 'mol_id') and the fitted
        StandardScaler (also persisted to SCALER_FILE).
    """
    logger.info("Fetching Tox21 dataset from Hugging Face Hub (ml-jku/tox21)...")
    dataset = load_dataset("ml-jku/tox21")

    # Merge splits for consistent processing or use them directly.
    # Here we process the official training set.
    df = dataset['train'].to_pandas()

    # Pre-process: keep only rows whose SMILES RDKit can parse.
    valid_mask = df['smiles'].apply(lambda s: Chem.MolFromSmiles(s) is not None)
    df = df[valid_mask].reset_index(drop=True)

    logger.info(f"Computing descriptors for {len(df)} molecules...")
    all_global_features = np.array(
        [get_global_features(Chem.MolFromSmiles(s)) for s in df['smiles']]
    )

    # Scaffold split on the cleaned SMILES list.
    train_idx, val_idx = scaffold_split(df['smiles'].tolist())

    # Fit the scaler on TRAIN ONLY to avoid leaking validation statistics,
    # then persist it so inference can apply the identical transform.
    scaler = StandardScaler()
    scaler.fit(all_global_features[train_idx])
    with open(SCALER_FILE, 'wb') as f:
        pickle.dump(scaler, f)
    all_global_features_scaled = scaler.transform(all_global_features)

    def format_subset(indices):
        # Materialize one sample dict per molecule index.
        data_list = []
        for original_idx in indices:
            row = df.iloc[original_idx]
            # Convert labels: leaderboard uses NaNs for missing data;
            # -1 signals "missing" to the masked loss function downstream.
            labels = row[TASKS].values.astype(np.float32)
            labels = np.nan_to_num(labels, nan=-1.0)
            data_list.append({
                'smiles': row['smiles'],
                'labels': torch.tensor(labels, dtype=torch.float32),
                'global_features': torch.tensor(
                    all_global_features_scaled[original_idx], dtype=torch.float32
                ),
                'mol_id': f"mol_{original_idx}",
            })
        return data_list

    return format_subset(train_idx), format_subset(val_idx), scaler