File size: 2,158 Bytes
94b1553
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Dict, List, Sequence

import numpy as np
import pandas as pd
from datasets import load_dataset
from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize

from .constants import CANONICAL_SMILES_COLUMN

@dataclass
class MoleculeBatch:
    """Result of standardizing a sequence of SMILES strings.

    As populated by ``standardize_smiles``: ``mols`` and ``canonical_smiles``
    are parallel lists holding only the inputs that were standardized
    successfully, while ``mask`` has one entry per input.
    """

    # RDKit molecules re-parsed from their canonical SMILES, in input order.
    mols: List[Chem.Mol]
    # Boolean array with one entry per input SMILES; True where the input
    # yielded a molecule, so ``mask.sum() == len(mols)``.
    mask: np.ndarray
    # Canonical SMILES for each molecule in ``mols`` (same order and length).
    canonical_smiles: List[str]


def load_tox21_dataset(token: str | None, dataset_name: str) -> Dict[str, pd.DataFrame]:
    """Fetch a Hugging Face dataset and convert every split to a pandas DataFrame.

    Args:
        token: Value forwarded as ``token`` to ``load_dataset``; may be ``None``.
        dataset_name: Dataset identifier passed to ``load_dataset``.

    Returns:
        Mapping from split name to that split's contents as a DataFrame.
    """
    hf_dataset = load_dataset(dataset_name, token=token)
    # One DataFrame per split, keyed by the split's name.
    return {name: hf_dataset[name].to_pandas() for name in hf_dataset.keys()}


def standardize_smiles(smiles: Sequence[str]) -> MoleculeBatch:
    """Standardize SMILES strings and return RDKit molecules with canonical SMILES.

    Each input is parsed, passed through ``rdMolStandardize.Cleanup``, reduced
    to its canonical tautomer, written back out as SMILES and parsed once more
    so the returned molecule corresponds exactly to the returned string.

    Processing is best-effort: an input that cannot be parsed, that describes
    no atoms (for example an empty string), or that makes RDKit raise is
    skipped instead of aborting the whole batch. Inputs skipped because of an
    exception are reported through the module logger.

    Args:
        smiles: Input SMILES strings.

    Returns:
        A ``MoleculeBatch`` whose ``mask`` holds one boolean per input (``True``
        where standardization succeeded) and whose ``mols`` and
        ``canonical_smiles`` hold the successful entries in input order.
    """
    logger = logging.getLogger(__name__)
    tautomer_enumerator = rdMolStandardize.TautomerEnumerator()
    cleanup_params = rdMolStandardize.CleanupParameters()

    mols: List[Chem.Mol] = []
    canonical_smiles: List[str] = []
    mask = np.zeros(len(smiles), dtype=bool)

    for idx, smi in enumerate(smiles):
        try:
            mol = Chem.MolFromSmiles(smi)
            # An empty SMILES string parses to a molecule with zero atoms
            # rather than to None; such a molecule carries no structure, so it
            # is rejected together with unparsable input.
            if mol is None or mol.GetNumAtoms() == 0:
                continue

            mol = rdMolStandardize.Cleanup(mol, cleanup_params)
            mol = tautomer_enumerator.Canonicalize(mol)
            canonical = Chem.MolToSmiles(mol)
            # Re-parse so the stored molecule is built from the canonical string.
            mol = Chem.MolFromSmiles(canonical)
            if mol is None:
                continue
        except Exception as exc:  # deliberate per-molecule best-effort boundary
            # Keep going, but leave a trace so dropped rows can be diagnosed.
            logger.warning("Skipping SMILES at index %d (%r): %s", idx, smi, exc)
            continue

        mols.append(mol)
        canonical_smiles.append(canonical)
        mask[idx] = True

    return MoleculeBatch(mols=mols, mask=mask, canonical_smiles=canonical_smiles)


def filter_dataframe_by_mask(df: pd.DataFrame, mask: np.ndarray, canonical_smiles: Sequence[str]) -> pd.DataFrame:
    """Keep the rows selected by ``mask`` and attach their canonical SMILES.

    Args:
        df: Source DataFrame; it is not modified.
        mask: Row selector passed to ``df.loc``.
        canonical_smiles: Values for the new column, one per retained row.

    Returns:
        A new DataFrame with a fresh 0-based index containing the selected rows
        plus the ``CANONICAL_SMILES_COLUMN`` column.
    """
    selected_rows = df.loc[mask]
    # Work on an independent copy so the caller's frame is left untouched.
    result = selected_rows.copy()
    result = result.reset_index(drop=True)
    result[CANONICAL_SMILES_COLUMN] = canonical_smiles
    return result