# NOTE(review): this region originally contained Hugging Face file-viewer
# residue (page chrome, "File size: 4,435 Bytes", git blame hashes, and a
# column ruler) that was not valid Python; replaced with this comment so the
# module parses.
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem.Scaffolds import MurckoScaffold
from collections import defaultdict
import logging
import numpy as np
import pickle
from sklearn.preprocessing import StandardScaler
from datasets import load_dataset # Added for leaderboard compliance
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Path where load_data() pickles the fitted StandardScaler for reuse at inference.
SCALER_FILE = "scaler.pkl"
# data.py
# The 12 Tox21 assay endpoints used as label columns in the dataset dataframe.
# NR-* = nuclear receptor panel, SR-* = stress response panel.
TASKS = [
"NR-AhR",
"NR-AR",
"NR-AR-LBD",
"NR-Aromatase",
"NR-ER",
"NR-ER-LBD",
"NR-PPAR-gamma",
"SR-ARE",
"SR-ATAD5",
"SR-HSE",
"SR-MMP",
"SR-p53",
]
class Tox21Dataset(Dataset):
    """Minimal ``torch.utils.data.Dataset`` over a pre-built list of samples.

    The samples are fully materialized up front (see ``load_data``); this
    class only provides length and index access for a DataLoader.
    """

    def __init__(self, data):
        # Keep a reference to the already-prepared sample list.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
def get_global_features(mol, n_features=217):
    """Compute a fixed-length RDKit descriptor vector for a molecule.

    Args:
        mol: RDKit ``Mol`` object, or ``None`` (e.g. for unparseable SMILES).
        n_features: Length of the returned vector. Defaults to 217 to match
            the feature dimension the trained model checkpoints expect.

    Returns:
        ``np.ndarray`` of shape ``(n_features,)``, dtype float32. A ``None``
        molecule yields all zeros; individual descriptors that raise are
        recorded as 0.0; the vector is zero-padded or truncated to exactly
        ``n_features`` entries.
    """
    if mol is None:
        return np.zeros(n_features, dtype=np.float32)
    res = []
    # Iterate every descriptor available in the current RDKit install; the
    # list length can differ across versions, hence the pad/truncate below.
    for _name, func in Descriptors.descList:
        try:
            val = func(mol)
        except Exception:  # was a bare except: — keep failures as 0.0 but
            val = 0.0      # don't swallow KeyboardInterrupt/SystemExit
        res.append(val)
    # Pad or truncate so the output always matches the model's input size.
    if len(res) < n_features:
        res.extend([0.0] * (n_features - len(res)))
    return np.array(res[:n_features], dtype=np.float32)
def scaffold_split(smiles_list, train_frac=0.9):
    """Deterministic Bemis-Murcko scaffold split into (train, val) index lists."""
    # Group molecule indices by scaffold SMILES; unparseable SMILES share "".
    buckets = defaultdict(list)
    for idx, smi in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smi)
        if mol:
            key = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=True)
        else:
            key = ""
        buckets[key].append(idx)

    # Largest scaffold groups first (stable sort keeps insertion order on ties).
    ordered_groups = sorted(buckets.values(), key=len, reverse=True)

    budget = train_frac * len(smiles_list)
    train_indices = []
    for group in ordered_groups:
        # NOTE(review): stops at the first group that would overflow the
        # budget, so the realized train fraction can land below train_frac.
        if len(train_indices) + len(group) > budget:
            break
        train_indices.extend(group)

    leftover = set(range(len(smiles_list))) - set(train_indices)
    return train_indices, list(leftover)
def load_data():
    """
    Loads Tox21 from Hugging Face Hub instead of local CSV.
    Matches leaderboard requirements for training verification.

    Returns:
        Tuple ``(train_samples, val_samples, scaler)`` where each sample list
        holds dicts with keys ``smiles``, ``labels`` (float32 tensor, -1.0 for
        missing labels), ``global_features`` (scaled float32 tensor), and
        ``mol_id``; ``scaler`` is the StandardScaler fitted on the train split
        (also pickled to SCALER_FILE as a side effect).
    """
    logger.info("Fetching Tox21 dataset from Hugging Face Hub (ml-jku/tox21)...")
    dataset = load_dataset("ml-jku/tox21")
    # Merge splits for consistent processing or use them directly
    # Here we process the official training set
    df = dataset['train'].to_pandas()

    # Parse each SMILES exactly once and reuse the Mol for both the validity
    # filter and the descriptor computation (was parsed twice before).
    mols = [Chem.MolFromSmiles(s) for s in df['smiles']]
    valid_mask = [m is not None for m in mols]
    df = df[valid_mask].reset_index(drop=True)
    mols = [m for m in mols if m is not None]

    logger.info(f"Computing descriptors for {len(df)} molecules...")
    all_global_features = np.array([get_global_features(m) for m in mols])

    # Scaffold Split
    train_idx, val_idx = scaffold_split(df['smiles'].tolist())

    # Fit the scaler on training rows only (avoids leakage into validation),
    # then persist it so inference can apply the identical transform.
    scaler = StandardScaler()
    scaler.fit(all_global_features[train_idx])
    with open(SCALER_FILE, 'wb') as f:
        pickle.dump(scaler, f)
    all_global_features_scaled = scaler.transform(all_global_features)

    def format_subset(indices):
        # Build the per-sample dicts consumed by Tox21Dataset / DataLoader.
        data_list = []
        for original_idx in indices:
            row = df.iloc[original_idx]
            # Convert labels: leaderboard uses NaNs for missing data
            labels = row[TASKS].values.astype(np.float32)
            labels = np.nan_to_num(labels, nan=-1.0)  # -1 signals missing for the loss function
            data_list.append({
                'smiles': row['smiles'],
                'labels': torch.tensor(labels, dtype=torch.float32),
                'global_features': torch.tensor(all_global_features_scaled[original_idx], dtype=torch.float32),
                'mol_id': f"mol_{original_idx}"
            })
        return data_list

    # Stray " |" extraction residue removed from the end of this line.
    return format_subset(train_idx), format_subset(val_idx), scaler