from __future__ import annotations from dataclasses import dataclass from typing import Dict, List, Sequence import numpy as np import pandas as pd from datasets import load_dataset from rdkit import Chem from rdkit.Chem.MolStandardize import rdMolStandardize from .constants import CANONICAL_SMILES_COLUMN @dataclass class MoleculeBatch: mols: List[Chem.Mol] mask: np.ndarray canonical_smiles: List[str] def load_tox21_dataset(token: str | None, dataset_name: str) -> Dict[str, pd.DataFrame]: """Load dataset splits from Hugging Face into pandas DataFrames.""" dataset = load_dataset(dataset_name, token=token) splits: Dict[str, pd.DataFrame] = {} for split_name in dataset.keys(): splits[split_name] = dataset[split_name].to_pandas() return splits def standardize_smiles(smiles: Sequence[str]) -> MoleculeBatch: """Standardize SMILES strings and return RDKit molecules with canonical SMILES.""" tautomer_enumerator = rdMolStandardize.TautomerEnumerator() cleanup_params = rdMolStandardize.CleanupParameters() mols: List[Chem.Mol] = [] canonical_smiles: List[str] = [] mask = np.zeros(len(smiles), dtype=bool) for idx, smi in enumerate(smiles): try: mol = Chem.MolFromSmiles(smi) if mol is None: continue mol = rdMolStandardize.Cleanup(mol, cleanup_params) mol = tautomer_enumerator.Canonicalize(mol) canonical = Chem.MolToSmiles(mol) mol = Chem.MolFromSmiles(canonical) if mol is None: continue mols.append(mol) canonical_smiles.append(canonical) mask[idx] = True except Exception: continue return MoleculeBatch(mols=mols, mask=mask, canonical_smiles=canonical_smiles) def filter_dataframe_by_mask(df: pd.DataFrame, mask: np.ndarray, canonical_smiles: Sequence[str]) -> pd.DataFrame: """Apply mask to dataframe and append canonical SMILES column.""" clean_df = df.loc[mask].copy().reset_index(drop=True) clean_df[CANONICAL_SMILES_COLUMN] = canonical_smiles return clean_df