# tox21-classifier / data.py
# (header retained from the Hugging Face file view: uploaded by sk16er,
#  commit message "Update data.py", commit a979357 verified)
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem.Scaffolds import MurckoScaffold
from collections import defaultdict
import logging
import numpy as np
import pickle
from sklearn.preprocessing import StandardScaler
from datasets import load_dataset # Added for leaderboard compliance
# Configure module-level logging (timestamped INFO messages).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Path where the StandardScaler fitted in load_data() is pickled, so that
# inference code can reload the exact training-time feature normalization.
SCALER_FILE = "scaler.pkl"
# The 12 Tox21 assay names. The order here fixes the column order of every
# example's 'labels' tensor produced by load_data(); do not reorder without
# retraining, since model outputs are indexed by this list.
TASKS = [
    "NR-AhR",
    "NR-AR",
    "NR-AR-LBD",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
]
class Tox21Dataset(Dataset):
    """Thin ``torch.utils.data.Dataset`` wrapper over a pre-built example list.

    Each element is expected to be a ready-to-collate example (see
    ``load_data``); this class adds no transformation of its own.
    """

    def __init__(self, data):
        # Keep a reference to the prepared examples; no copying or validation.
        self.data = data

    def __len__(self):
        """Return the number of examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the example at position ``idx`` unchanged."""
        return self.data[idx]
def get_global_features(mol, n_features=217):
    """Compute a fixed-length RDKit descriptor vector for one molecule.

    Parameters
    ----------
    mol : rdkit.Chem.Mol or None
        Parsed molecule; ``None`` (unparseable SMILES) yields an all-zero vector.
    n_features : int, optional
        Length of the output vector. Defaults to 217 to match the feature
        dimension the existing model checkpoints were trained with.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_features,)``, dtype float32. The descriptor list is padded
        with zeros or truncated so the length is always exactly ``n_features``.
    """
    if mol is None:
        return np.zeros(n_features, dtype=np.float32)
    res = []
    for _name, func in Descriptors.descList:
        # Individual RDKit descriptors can raise on unusual molecules;
        # substitute 0.0 rather than aborting the whole vector.
        # (Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.)
        try:
            res.append(func(mol))
        except Exception:
            res.append(0.0)
    # Pad or truncate to exactly n_features for checkpoint compatibility:
    # Descriptors.descList length varies across RDKit versions.
    if len(res) < n_features:
        res.extend([0.0] * (n_features - len(res)))
    return np.array(res[:n_features], dtype=np.float32)
def scaffold_split(smiles_list, train_frac=0.9):
    """Bemis-Murcko scaffold split into train/validation index lists.

    Molecules sharing a scaffold always land in the same subset, so the
    validation set contains scaffolds unseen during training.

    Parameters
    ----------
    smiles_list : list of str
        SMILES strings; indices into this list are what gets returned.
    train_frac : float, optional
        Target fraction of molecules assigned to training (default 0.9).

    Returns
    -------
    tuple (train_indices, val_indices)
        Two lists of int indices partitioning ``range(len(smiles_list))``.
    """
    scaffolds = defaultdict(list)
    for i, smiles in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smiles)
        # Unparseable SMILES all share the empty-string scaffold bucket.
        scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=True) if mol else ""
        scaffolds[scaffold].append(i)
    # Largest scaffold groups first — the standard deterministic ordering.
    scaffold_sets = sorted(scaffolds.values(), key=len, reverse=True)
    train_indices = []
    val_indices = []
    train_cutoff = train_frac * len(smiles_list)
    for scaffold_set in scaffold_sets:
        # BUG FIX: the original `break` on the first oversized group sent
        # every remaining (smaller) group to validation, which can leave the
        # training set far below train_frac. The standard scaffold split
        # (DeepChem/chemprop) keeps filling train with groups that still fit.
        if len(train_indices) + len(scaffold_set) > train_cutoff:
            val_indices.extend(scaffold_set)
        else:
            train_indices.extend(scaffold_set)
    return train_indices, val_indices
def load_data():
    """Load the Tox21 training split from the Hugging Face Hub and prepare it.

    Steps: download ``ml-jku/tox21``, drop unparseable SMILES, compute global
    RDKit descriptor features, scaffold-split into train/validation, fit a
    ``StandardScaler`` on the training rows only (persisted to ``SCALER_FILE``
    so inference can reproduce the normalization), and format both subsets as
    lists of example dicts ready for ``Tox21Dataset``.

    Returns
    -------
    tuple (train_data, val_data, scaler)
        ``train_data``/``val_data`` are lists of dicts with keys:
        ``'smiles'`` (str), ``'labels'`` (float32 tensor over TASKS, with
        -1.0 marking missing labels), ``'global_features'`` (scaled float32
        tensor), ``'mol_id'`` (str). ``scaler`` is the fitted StandardScaler.
    """
    logger.info("Fetching Tox21 dataset from Hugging Face Hub (ml-jku/tox21)...")
    dataset = load_dataset("ml-jku/tox21")
    # Process the official training split; validation is carved out below
    # via the scaffold split.
    df = dataset['train'].to_pandas()
    # Parse each SMILES exactly once and reuse the Mol objects. The original
    # parsed every molecule twice (once for the validity mask, once for
    # featurization), doubling the most expensive preprocessing step.
    mols = [Chem.MolFromSmiles(s) for s in df['smiles']]
    valid_mask = np.array([m is not None for m in mols])
    df = df[valid_mask].reset_index(drop=True)
    mols = [m for m in mols if m is not None]
    logger.info(f"Computing descriptors for {len(df)} molecules...")
    all_global_features = np.array([get_global_features(m) for m in mols])
    # Scaffold split keeps molecules with a shared scaffold in the same subset.
    train_idx, val_idx = scaffold_split(df['smiles'].tolist())
    # Fit the scaler on training rows only (no leakage from validation),
    # then persist it for inference-time reuse.
    scaler = StandardScaler()
    scaler.fit(all_global_features[train_idx])
    with open(SCALER_FILE, 'wb') as f:
        pickle.dump(scaler, f)
    all_global_features_scaled = scaler.transform(all_global_features)

    def format_subset(indices):
        # Build one example dict per molecule index in `indices`.
        data_list = []
        for original_idx in indices:
            row = df.iloc[original_idx]
            # The leaderboard data uses NaN for missing labels; remap to -1.0,
            # which the loss function interprets as "mask this task out".
            labels = row[TASKS].values.astype(np.float32)
            labels = np.nan_to_num(labels, nan=-1.0)
            data_list.append({
                'smiles': row['smiles'],
                'labels': torch.tensor(labels, dtype=torch.float32),
                'global_features': torch.tensor(all_global_features_scaled[original_idx], dtype=torch.float32),
                'mol_id': f"mol_{original_idx}"
            })
        return data_list

    return format_subset(train_idx), format_subset(val_idx), scaler