"""Utilities for converting ORD-style reaction records into forward- and
retro-synthesis training examples (SMILES source/target pairs)."""

import random
from typing import Any, Dict, Iterable, Iterator, List, Tuple

from rdkit import Chem
from rdkit.Chem import MolToSmiles


def canonicalize(smiles: str) -> str:
    """Return the RDKit-canonical form of *smiles*.

    Best-effort: if RDKit cannot parse the string (or raises), the input is
    returned unchanged so one bad molecule never aborts a whole batch.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Unparseable SMILES: keep the original text.
            return smiles
        return MolToSmiles(mol, canonical=True)
    except Exception as e:  # deliberate best-effort boundary
        print(f"Error canonicalizing {smiles}: {e}")
        return smiles


def join_smiles(items: Iterable[str]) -> str:
    """Join SMILES strings with '.' after de-duplicating and sorting.

    Blank / whitespace-only entries are dropped; sorting makes the output
    deterministic regardless of input order.
    """
    unique_smiles = sorted({s.strip() for s in items if s and s.strip()})
    return ".".join(unique_smiles)


def _flatten_smiles(value: Any) -> List[str]:
    """Recursively flatten nested SMILES containers into a flat list.

    Handles plain strings, dicts carrying a "smiles" key, and arbitrary
    list/tuple/set nesting. ``None`` yields an empty list; any other
    scalar is stringified.
    """
    if value is None:
        return []
    if isinstance(value, str):
        return [value]
    if isinstance(value, dict):
        # Some datasets nest SMILES under keys like "smiles".
        if "smiles" in value:
            return _flatten_smiles(value.get("smiles"))
        return []
    if isinstance(value, (list, tuple, set)):
        flattened: List[str] = []
        for item in value:
            flattened.extend(_flatten_smiles(item))
        return flattened
    return [str(value)]


def _extract_reactants_products(example: Dict) -> Tuple[List[str], List[str]]:
    """Extract (reactants, products) SMILES lists from several schema variants.

    Supports both the ORD protobuf-style nesting ("inputs" / "products")
    and flattened dataset fields ("inputs_smiles" / "products_smiles").
    The two schemas are additive: a record carrying both contributes from
    both, matching the original behavior.
    """
    reactants: List[str] = []
    products: List[str] = []

    # ORD protobuf-style structure.
    for input_block in example.get("inputs", []) or []:
        if isinstance(input_block, dict):
            reactants.extend(_flatten_smiles(input_block.get("components")))
    for product_block in example.get("products", []) or []:
        if isinstance(product_block, dict):
            products.extend(_flatten_smiles(product_block))

    # Flattened dataset fields (inputs_smiles / products_smiles).
    if "inputs_smiles" in example:
        reactants.extend(_flatten_smiles(example.get("inputs_smiles")))
    if "products_smiles" in example:
        products.extend(_flatten_smiles(example.get("products_smiles")))

    return reactants, products


def forward_example(example: Dict) -> Dict:
    """Convert a raw ORD example to forward synthesis (reactants -> products).

    Returns a ``{"source": ..., "target": ...}`` dict of canonical,
    '.'-joined SMILES, or ``{}`` when either side is empty or processing
    fails.
    """
    try:
        reactants_raw, products_raw = _extract_reactants_products(example)
        if not reactants_raw or not products_raw:
            return {}
        reactants_canonical = [canonicalize(x) for x in reactants_raw if x]
        products_canonical = [canonicalize(x) for x in products_raw if x]
        if not reactants_canonical or not products_canonical:
            return {}
        return {
            "source": join_smiles(reactants_canonical),
            "target": join_smiles(products_canonical),
        }
    except Exception as e:  # best-effort: one bad record must not stop a run
        print(f"Error processing forward example: {e}")
        return {}


def retro_example(example: Dict) -> Dict:
    """Convert an example to the retrosynthesis direction (products -> reactants)."""
    forward = forward_example(example)
    if not forward:
        return {}
    return {"source": forward["target"], "target": forward["source"]}


def split_dataset_indices(seed: int, train_ratio: float) -> Iterator[str]:
    """Yield an endless, reproducible stream of split labels.

    Each draw is "train" with probability *train_ratio*; the remainder is
    split evenly between "validation" and "test". Uses a private
    ``random.Random(seed)`` instance instead of seeding the global RNG, so
    callers' random state is not disturbed — the emitted sequence is
    identical to the original global-seed version.
    """
    rng = random.Random(seed)
    while True:
        rnd = rng.random()
        if rnd < train_ratio:
            yield "train"
        elif rnd < train_ratio + (1 - train_ratio) / 2:
            yield "validation"
        else:
            yield "test"


def deduplicate_examples(examples: List[Dict]) -> List[Dict]:
    """Remove duplicate reactions keyed on the source>>target pair.

    Empty examples are dropped (as before). First occurrence wins; input
    order is otherwise preserved.
    """
    seen = set()
    unique = []
    for ex in examples:
        if not ex:
            # Empty records carry no reaction; skip without recording a key
            # (the original likewise never added them or their key).
            continue
        # '>>' is the reaction-SMILES separator the docstring promised; the
        # previous single '>' could collide distinct (source, target) pairs.
        key = f"{ex.get('source', '')}>>{ex.get('target', '')}"
        if key not in seen:
            seen.add(key)
            unique.append(ex)
    return unique