"""Tox21 data loading and featurization utilities (data.py)."""
# data.py — imports, logging setup, and module-level constants.
import logging
import os
import pickle
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

from datasets import load_dataset  # Hugging Face Hub access (leaderboard compliance)
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem.Scaffolds import MurckoScaffold
from sklearn.preprocessing import StandardScaler

# Module-wide logging configuration.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Path where the fitted StandardScaler is persisted for reuse at inference time.
SCALER_FILE = "scaler.pkl"
# The 12 Tox21 assay target columns: nuclear-receptor (NR-*) and
# stress-response (SR-*) panels, in the order the model's output head expects.
TASKS = [
    "NR-AhR",
    "NR-AR",
    "NR-AR-LBD",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
]
class Tox21Dataset(Dataset):
    """Thin ``torch.utils.data.Dataset`` wrapper over a pre-built sample list.

    Each element of ``data`` is expected to be one ready-to-use sample
    (here: a dict produced by ``load_data``); no transformation is applied.
    """

    def __init__(self, data):
        # Keep a reference to the indexable sequence of samples as-is.
        self.data = data

    def __len__(self):
        """Number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` unchanged."""
        return self.data[idx]
def get_global_features(mol, n_features=217):
    """Compute a fixed-length RDKit descriptor vector for a molecule.

    Args:
        mol: An RDKit ``Mol``, or ``None`` for an unparsable molecule.
        n_features: Output vector length. Defaults to 217 to match the
            trained model checkpoints' expected input width.

    Returns:
        ``np.ndarray`` of shape ``(n_features,)``, dtype float32. Descriptors
        that raise are recorded as 0.0; the vector is zero-padded or
        truncated to exactly ``n_features``.
    """
    if mol is None:
        # Invalid molecule: all-zero feature vector of the expected width.
        return np.zeros(n_features, dtype=np.float32)

    values = []
    for _name, func in Descriptors.descList:
        try:
            values.append(func(mol))
        except Exception:
            # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
            # are not swallowed; a failing descriptor contributes 0.0.
            values.append(0.0)

    # Pad or truncate so the output always matches the checkpoint input width.
    if len(values) < n_features:
        values.extend([0.0] * (n_features - len(values)))
    return np.array(values[:n_features], dtype=np.float32)
def scaffold_split(smiles_list, train_frac=0.9):
    """Standard Bemis-Murcko scaffold split for molecular data.

    Groups molecules by scaffold, then fills the training set with the
    largest groups first until ``train_frac`` of the data is reached;
    every remaining molecule goes to validation.

    Returns:
        (train_indices, val_indices) — index lists into ``smiles_list``.
    """
    # Group molecule indices by scaffold; unparsable SMILES share the "" key.
    by_scaffold = defaultdict(list)
    for idx, smi in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smi)
        key = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=True) if mol else ""
        by_scaffold[key].append(idx)

    # Largest scaffold groups first; stop once adding a group would
    # overshoot the training cutoff (whole groups only — no scaffold leaks).
    groups = sorted(by_scaffold.values(), key=len, reverse=True)
    cutoff = train_frac * len(smiles_list)
    train_indices = []
    for group in groups:
        if len(train_indices) + len(group) > cutoff:
            break
        train_indices.extend(group)

    # Everything not assigned to train becomes validation.
    val_indices = list(set(range(len(smiles_list))) - set(train_indices))
    return train_indices, val_indices
def load_data():
    """Load Tox21 from the Hugging Face Hub and build train/val sample lists.

    Matches leaderboard requirements for training verification. Fits a
    StandardScaler on the training split only (no leakage) and persists it
    to ``SCALER_FILE``.

    Returns:
        (train_list, val_list, scaler) where each list holds dicts with keys
        'smiles', 'labels' (float32 tensor; -1.0 marks a missing label),
        'global_features' (scaled float32 tensor), and 'mol_id'.
    """
    logger.info("Fetching Tox21 dataset from Hugging Face Hub (ml-jku/tox21)...")
    dataset = load_dataset("ml-jku/tox21")
    # Process the official training set; the scaffold split below
    # carves out our own validation subset.
    df = dataset['train'].to_pandas()

    # Parse every SMILES exactly once and reuse the Mol objects for both
    # validity filtering and descriptor computation (the original parsed
    # each molecule twice).
    mols = [Chem.MolFromSmiles(s) for s in df['smiles']]
    valid_mask = [m is not None for m in mols]
    df = df[valid_mask].reset_index(drop=True)
    mols = [m for m in mols if m is not None]

    logger.info(f"Computing descriptors for {len(df)} molecules...")
    all_global_features = np.array([get_global_features(m) for m in mols])

    # Scaffold split over the valid molecules.
    train_idx, val_idx = scaffold_split(df['smiles'].tolist())

    # Fit the scaler on train only, persist it for inference, then scale all.
    scaler = StandardScaler()
    scaler.fit(all_global_features[train_idx])
    with open(SCALER_FILE, 'wb') as f:
        pickle.dump(scaler, f)
    all_global_features_scaled = scaler.transform(all_global_features)

    def format_subset(indices):
        # Build the per-sample dicts consumed by Tox21Dataset.
        data_list = []
        for original_idx in indices:
            row = df.iloc[original_idx]
            # Leaderboard data uses NaN for missing labels; -1.0 signals
            # "missing" to the masked loss function downstream.
            labels = row[TASKS].values.astype(np.float32)
            labels = np.nan_to_num(labels, nan=-1.0)
            data_list.append({
                'smiles': row['smiles'],
                'labels': torch.tensor(labels, dtype=torch.float32),
                'global_features': torch.tensor(all_global_features_scaled[original_idx], dtype=torch.float32),
                'mol_id': f"mol_{original_idx}"
            })
        return data_list

    return format_subset(train_idx), format_subset(val_idx), scaler