# NOTE(review): removed non-Python extraction artifacts that preceded this
# module (a file-size banner, stray commit hashes, and a duplicated
# line-number gutter) — they were not part of the source and would be
# syntax errors at import time.
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem.Scaffolds import MurckoScaffold
from collections import defaultdict
import logging
import numpy as np
import pickle
from sklearn.preprocessing import StandardScaler
from datasets import load_dataset  # Added for leaderboard compliance

# Configure module-level logging (INFO and up, timestamped).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Path where load_data() persists the fitted StandardScaler so that
# inference code can apply the exact same feature scaling.
SCALER_FILE = "scaler.pkl"

# The 12 Tox21 assay targets. Order matters: labels are read out of the
# dataframe in this column order inside load_data()/format_subset.
TASKS = [
    "NR-AhR",
    "NR-AR",
    "NR-AR-LBD",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
]


class Tox21Dataset(Dataset):
    """Minimal ``torch.utils.data.Dataset`` over a pre-built list of samples.

    Each element of ``data`` is expected to be a fully materialised sample
    (a dict produced by ``load_data``); no lazy loading or transformation
    happens here.
    """

    def __init__(self, data):
        # Keep a direct reference; samples are already featurised upstream.
        self.data = data

    def __len__(self):
        """Number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the pre-built sample at position ``idx``."""
        return self.data[idx]

def get_global_features(mol):
    """Compute a fixed-length global descriptor vector for a molecule.

    Parameters
    ----------
    mol : rdkit.Chem.Mol or None
        Parsed molecule. ``None`` (unparseable SMILES) yields an all-zero
        vector so callers never have to special-case it.

    Returns
    -------
    numpy.ndarray
        Shape ``(217,)``, dtype ``float32``. The length is padded/truncated
        to exactly 217 to stay compatible with the trained model checkpoints
        regardless of how many descriptors the installed RDKit exposes.
    """
    N_FEATURES = 217  # must match the checkpoint's expected input width
    if mol is None:
        return np.zeros(N_FEATURES, dtype=np.float32)

    res = []
    for _name, func in Descriptors.descList:
        try:
            val = float(func(mol))
        except Exception:
            # Some descriptors raise on exotic molecules; fall back to 0.0.
            # (Was a bare `except:`, which would also swallow KeyboardInterrupt.)
            val = 0.0
        # Guard against NaN/inf descriptor values (e.g. partial-charge
        # descriptors on molecules RDKit cannot parameterise) — a single
        # non-finite entry would poison the StandardScaler fit downstream.
        res.append(val if np.isfinite(val) else 0.0)

    # Pad or truncate to exactly N_FEATURES.
    if len(res) < N_FEATURES:
        res.extend([0.0] * (N_FEATURES - len(res)))
    return np.array(res[:N_FEATURES], dtype=np.float32)

def scaffold_split(smiles_list, train_frac=0.9):
    """Deterministic Bemis–Murcko scaffold split.

    Molecules are grouped by scaffold so that no scaffold straddles the
    train/validation boundary. Groups are taken largest-first into the
    training set until ``train_frac`` of the data is reached; everything
    else becomes validation.

    Parameters
    ----------
    smiles_list : list[str]
        SMILES strings; unparseable ones all share the "" pseudo-scaffold.
    train_frac : float, optional
        Target fraction of molecules in the training set (default 0.9).

    Returns
    -------
    (list[int], list[int])
        ``(train_indices, val_indices)`` into ``smiles_list``.
    """
    scaffolds = defaultdict(list)
    for i, smiles in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smiles)
        scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=True) if mol else ""
        scaffolds[scaffold].append(i)

    # Largest scaffold groups first; sort is stable, so ties keep insertion order.
    scaffold_sets = sorted(scaffolds.values(), key=len, reverse=True)
    train_indices = []
    train_cutoff = train_frac * len(smiles_list)

    for scaffold_set in scaffold_sets:
        if len(train_indices) + len(scaffold_set) > train_cutoff:
            # BUGFIX: this used to `break`, so one oversized group pushed every
            # remaining (smaller) group into validation and the train set could
            # end up far below train_frac. Skip the group and keep trying
            # smaller ones instead.
            continue
        train_indices.extend(scaffold_set)

    # Everything not chosen for training goes to validation; sorted for
    # deterministic ordering across runs.
    val_indices = sorted(set(range(len(smiles_list))) - set(train_indices))
    return train_indices, val_indices

def load_data():
    """Load, featurise, scaffold-split, and scale the Tox21 dataset.

    Downloads ``ml-jku/tox21`` from the Hugging Face Hub, drops rows whose
    SMILES RDKit cannot parse, computes a (N, 217) global-descriptor matrix,
    performs a 90/10 scaffold split, fits a ``StandardScaler`` on the
    training rows only (and persists it to ``SCALER_FILE`` for inference-time
    reuse), and materialises both subsets.

    Returns
    -------
    (list[dict], list[dict], StandardScaler)
        ``(train_samples, val_samples, scaler)``; each sample dict has keys
        ``'smiles'``, ``'labels'`` (float32 tensor, missing assays encoded as
        -1.0), ``'global_features'`` (scaled float32 tensor), ``'mol_id'``.
    """
    logger.info("Fetching Tox21 dataset from Hugging Face Hub (ml-jku/tox21)...")
    dataset = load_dataset("ml-jku/tox21")

    # Process the official training split; our own validation set comes from
    # the scaffold split below.
    df = dataset['train'].to_pandas()

    # Parse every SMILES exactly once and reuse the Mol objects for both the
    # validity filter and descriptor computation (previously each molecule
    # was parsed twice: once for the mask, once for featurisation).
    mols = [Chem.MolFromSmiles(s) for s in df['smiles']]
    valid_mask = [m is not None for m in mols]
    df = df[valid_mask].reset_index(drop=True)
    mols = [m for m in mols if m is not None]

    logger.info(f"Computing descriptors for {len(df)} molecules...")
    all_global_features = np.array([get_global_features(m) for m in mols])

    # Scaffold split on the filtered data.
    train_idx, val_idx = scaffold_split(df['smiles'].tolist())

    # Fit the scaler on training rows only (avoids leakage from validation),
    # then persist it so inference applies the identical transform.
    scaler = StandardScaler()
    scaler.fit(all_global_features[train_idx])
    with open(SCALER_FILE, 'wb') as f:
        pickle.dump(scaler, f)

    all_global_features_scaled = scaler.transform(all_global_features)

    def format_subset(indices):
        """Materialise the samples at *indices* as a list of dicts."""
        data_list = []
        for original_idx in indices:
            row = df.iloc[original_idx]
            labels = row[TASKS].values.astype(np.float32)
            # The source data marks missing assay results as NaN; encode them
            # as -1.0 so the loss function can mask them out.
            labels = np.nan_to_num(labels, nan=-1.0)
            data_list.append({
                'smiles': row['smiles'],
                'labels': torch.tensor(labels, dtype=torch.float32),
                'global_features': torch.tensor(all_global_features_scaled[original_idx], dtype=torch.float32),
                'mol_id': f"mol_{original_idx}",
            })
        return data_list

    return format_subset(train_idx), format_subset(val_idx), scaler