# Orchestrate full ORD training pipeline in Space
import random
from typing import Any, Dict, Iterable, Iterator, List, Sequence, Tuple

from rdkit import Chem
from rdkit.Chem import MolToSmiles
def canonicalize(smiles: str) -> str:
    """Return the RDKit canonical form of *smiles*.

    Falls back to the original string when RDKit cannot parse it or when
    any error occurs, so callers never lose data.
    """
    try:
        parsed = Chem.MolFromSmiles(smiles)
        return smiles if parsed is None else MolToSmiles(parsed, canonical=True)
    except Exception as e:
        print(f"Error canonicalizing {smiles}: {e}")
        return smiles
def join_smiles(items: Iterable[str]) -> str:
    """Join SMILES strings with '.' separator, removing duplicates and sorting."""
    # Strip whitespace, drop empty/blank entries, and dedupe via a set.
    cleaned = {entry.strip() for entry in items if entry and entry.strip()}
    return ".".join(sorted(cleaned))
def _flatten_smiles(value: Any) -> List[str]:
"""Recursively flatten SMILES structures (lists, tuples, sets, dicts)."""
if value is None:
return []
if isinstance(value, str):
return [value]
if isinstance(value, dict):
# Some datasets nest SMILES under keys like "smiles"
if "smiles" in value:
return _flatten_smiles(value.get("smiles"))
return []
if isinstance(value, (list, tuple, set)):
flattened: List[str] = []
for item in value:
flattened.extend(_flatten_smiles(item))
return flattened
return [str(value)]
def _extract_reactants_products(example: Dict) -> (List[str], List[str]):
"""Extract reactant and product SMILES from multiple possible schema variants."""
reactants: List[str] = []
products: List[str] = []
# ORD protobuf-style structure
for input_block in example.get("inputs", []) or []:
if isinstance(input_block, dict):
reactants.extend(_flatten_smiles(input_block.get("components")))
for product_block in example.get("products", []) or []:
if isinstance(product_block, dict):
products.extend(_flatten_smiles(product_block))
# Flattened dataset fields (inputs_smiles/products_smiles)
if "inputs_smiles" in example:
reactants.extend(_flatten_smiles(example.get("inputs_smiles")))
if "products_smiles" in example:
products.extend(_flatten_smiles(example.get("products_smiles")))
return reactants, products
def forward_example(example: Dict) -> Dict:
    """Convert raw ORD example to forward synthesis (reactants -> products).

    Returns an empty dict when either side of the reaction is missing or
    when any error occurs during extraction/canonicalization.
    """
    try:
        raw_reactants, raw_products = _extract_reactants_products(example)
        if not raw_reactants or not raw_products:
            return {}
        canon_reactants = [canonicalize(smi) for smi in raw_reactants if smi]
        canon_products = [canonicalize(smi) for smi in raw_products if smi]
        if not canon_reactants or not canon_products:
            return {}
        return {
            "source": join_smiles(canon_reactants),
            "target": join_smiles(canon_products),
        }
    except Exception as e:
        print(f"Error processing forward example: {e}")
        return {}
def retro_example(example: Dict) -> Dict:
    """Convert forward synthesis example to retrosynthesis (products -> reactants)."""
    converted = forward_example(example)
    if not converted:
        return {}
    # Swap direction: the model predicts reactants from products.
    return {"source": converted["target"], "target": converted["source"]}
def split_dataset_indices(seed: int, train_ratio: float) -> Iterator[str]:
    """Generate random split assignments for dataset.

    Yields an endless stream of "train"/"validation"/"test" labels; the
    non-train probability mass (1 - train_ratio) is divided evenly between
    validation and test.

    Args:
        seed: RNG seed; the same seed reproduces the same label stream.
        train_ratio: Fraction of draws assigned to "train" (0..1).

    Yields:
        One of "train", "validation", or "test" per draw.
    """
    # Use a private Random instance instead of random.seed() so we don't
    # clobber the module-level RNG state for other code in the process.
    # Random(seed).random() produces the exact same sequence as
    # random.seed(seed) followed by random.random() calls.
    rng = random.Random(seed)
    # Hoisted loop invariant: validation gets half the leftover mass.
    validation_cutoff = train_ratio + (1 - train_ratio) / 2
    while True:
        rnd = rng.random()
        if rnd < train_ratio:
            yield "train"
        elif rnd < validation_cutoff:
            yield "validation"
        else:
            yield "test"
def deduplicate_examples(examples: List[Dict]) -> List[Dict]:
    """Remove duplicate reactions based on the (source, target) pair.

    Keeps the first occurrence of each reaction; falsy entries (e.g. empty
    dicts produced by failed conversions) are dropped.

    Args:
        examples: Mapped examples with "source"/"target" SMILES strings.

    Returns:
        A new list, input order preserved, without duplicates or empty examples.
    """
    seen = set()
    unique: List[Dict] = []
    for ex in examples:
        if not ex:
            # Guard first: skip empty/failed examples before any key work.
            continue
        # Tuple key instead of a separator-joined string: the old
        # f"{source}>{target}" key could collide when a value itself
        # contained the ">" separator.
        key = (ex.get("source", ""), ex.get("target", ""))
        if key not in seen:
            seen.add(key)
            unique.append(ex)
    return unique