|
|
import random
from typing import Any, Dict, Iterable, Iterator, List, Sequence, Tuple

from rdkit import Chem
from rdkit.Chem import MolToSmiles
|
|
|
|
|
|
|
|
def canonicalize(smiles: str) -> str:
    """Return the RDKit canonical form of *smiles*, falling back to the input.

    The input string is returned unchanged when RDKit's parser yields
    ``None`` or when any exception is raised along the way; in the latter
    case a diagnostic line is printed to stdout.
    """
    try:
        parsed = Chem.MolFromSmiles(smiles)
        # A ``None`` molecule means the string could not be parsed; keep it
        # verbatim instead of failing.
        result = smiles if parsed is None else MolToSmiles(parsed, canonical=True)
    except Exception as err:
        print(f"Error canonicalizing {smiles}: {err}")
        result = smiles
    return result
|
|
|
|
|
|
|
|
def join_smiles(items: Iterable[str]) -> str:
    """Build a '.'-separated string of the distinct, sorted SMILES in *items*.

    Each entry is stripped of surrounding whitespace; falsy entries and
    entries that are empty after stripping are ignored.
    """
    distinct = set()
    for raw in items:
        if not raw:
            continue
        trimmed = raw.strip()
        if trimmed:
            distinct.add(trimmed)
    return ".".join(sorted(distinct))
|
|
|
|
|
|
|
|
def _flatten_smiles(value: Any) -> List[str]: |
|
|
"""Recursively flatten SMILES structures (lists, tuples, sets, dicts).""" |
|
|
if value is None: |
|
|
return [] |
|
|
if isinstance(value, str): |
|
|
return [value] |
|
|
if isinstance(value, dict): |
|
|
|
|
|
if "smiles" in value: |
|
|
return _flatten_smiles(value.get("smiles")) |
|
|
return [] |
|
|
if isinstance(value, (list, tuple, set)): |
|
|
flattened: List[str] = [] |
|
|
for item in value: |
|
|
flattened.extend(_flatten_smiles(item)) |
|
|
return flattened |
|
|
return [str(value)] |
|
|
|
|
|
|
|
|
def _extract_reactants_products(example: Dict) -> (List[str], List[str]): |
|
|
"""Extract reactant and product SMILES from multiple possible schema variants.""" |
|
|
reactants: List[str] = [] |
|
|
products: List[str] = [] |
|
|
|
|
|
|
|
|
for input_block in example.get("inputs", []) or []: |
|
|
if isinstance(input_block, dict): |
|
|
reactants.extend(_flatten_smiles(input_block.get("components"))) |
|
|
|
|
|
for product_block in example.get("products", []) or []: |
|
|
if isinstance(product_block, dict): |
|
|
products.extend(_flatten_smiles(product_block)) |
|
|
|
|
|
|
|
|
if "inputs_smiles" in example: |
|
|
reactants.extend(_flatten_smiles(example.get("inputs_smiles"))) |
|
|
if "products_smiles" in example: |
|
|
products.extend(_flatten_smiles(example.get("products_smiles"))) |
|
|
|
|
|
return reactants, products |
|
|
|
|
|
|
|
|
def forward_example(example: Dict) -> Dict:
    """Convert a raw example to a forward-synthesis pair (reactants -> products).

    Reactant and product SMILES are extracted, canonicalized one by one and
    joined into '.'-separated strings.

    Returns:
        ``{"source": <reactants>, "target": <products>}``, or an empty dict
        when either side is missing or empty, or when any exception is raised
        during processing (a diagnostic line is printed in that case).
    """
    try:
        reactants_raw, products_raw = _extract_reactants_products(example)
        if not reactants_raw or not products_raw:
            return {}

        source = join_smiles(canonicalize(x) for x in reactants_raw if x)
        target = join_smiles(canonicalize(x) for x in products_raw if x)

        # join_smiles discards whitespace-only entries, so a non-empty list
        # can still join to "". Check the joined strings so that a pair with
        # an empty side is never emitted.
        if not source or not target:
            return {}

        return {
            "source": source,
            "target": target,
        }
    except Exception as e:
        print(f"Error processing forward example: {e}")
        return {}
|
|
|
|
|
|
|
|
def retro_example(example: Dict) -> Dict:
    """Build a retrosynthesis pair (products -> reactants) from *example*.

    The example is first converted with ``forward_example``; its ``source``
    and ``target`` are then swapped. An empty dict is returned when the
    forward conversion produced nothing.
    """
    fwd = forward_example(example)
    if fwd:
        return {"source": fwd["target"], "target": fwd["source"]}
    return {}
|
|
|
|
|
|
|
|
def split_dataset_indices(seed: int, train_ratio: float) -> Iterator[str]:
    """Yield an endless, reproducible stream of split labels.

    Each draw yields ``"train"`` with probability ``train_ratio``; the
    remaining probability mass is divided evenly between ``"validation"``
    and ``"test"``.

    Args:
        seed: Seed for the generator's private random number generator.
        train_ratio: Fraction of draws that should be labelled ``"train"``.

    Yields:
        One of ``"train"``, ``"validation"`` or ``"test"``; never terminates.
    """
    # A dedicated Random instance keeps the sequence independent of (and
    # invisible to) any other user of the module-level RNG. For a given seed
    # it produces the same draws as seeding the global generator would.
    rng = random.Random(seed)
    validation_cutoff = train_ratio + (1 - train_ratio) / 2
    while True:
        draw = rng.random()
        if draw < train_ratio:
            yield "train"
        elif draw < validation_cutoff:
            yield "validation"
        else:
            yield "test"
|
|
|
|
|
|
|
|
def deduplicate_examples(examples: List[Dict]) -> List[Dict]:
    """Drop repeated reactions, keeping the first occurrence of each pair.

    Two examples are duplicates when both their ``"source"`` and ``"target"``
    values (missing keys count as ``""``) have the same string form. Falsy
    entries such as empty dicts are skipped.

    Args:
        examples: Example dicts, typically with ``"source"``/``"target"`` keys.

    Returns:
        A new list with the surviving examples in their original order.
    """
    seen = set()
    unique = []

    for ex in examples:
        if not ex:
            continue
        # A (source, target) tuple cannot be confused across the boundary the
        # way a single delimiter-joined string can: ("A>B", "C") and
        # ("A", "B>C") stay distinct.
        key = (str(ex.get("source", "")), str(ex.get("target", "")))
        if key not in seen:
            seen.add(key)
            unique.append(ex)

    return unique
|
|
|