diff --git a/SynTool/__init__.py b/SynTool/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..2a04536789e8e3660c5758c092982bd4507a79f5 --- /dev/null +++ b/SynTool/__init__.py @@ -0,0 +1,3 @@ +from .mcts import * + +__all__ = ["Tree"] diff --git a/SynTool/chem/__init__.py b/SynTool/chem/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SynTool/chem/__pycache__/__init__.cpython-310.pyc b/SynTool/chem/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46b692ffe41b12b6ef257284ff49a790767a925f Binary files /dev/null and b/SynTool/chem/__pycache__/__init__.cpython-310.pyc differ diff --git a/SynTool/chem/__pycache__/reaction.cpython-310.pyc b/SynTool/chem/__pycache__/reaction.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd5c09272e6aff74cc2a930d6ce2685506ff7e3b Binary files /dev/null and b/SynTool/chem/__pycache__/reaction.cpython-310.pyc differ diff --git a/SynTool/chem/__pycache__/retron.cpython-310.pyc b/SynTool/chem/__pycache__/retron.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e56eac834ed51580ae0305bbdf19250aeabe0c5c Binary files /dev/null and b/SynTool/chem/__pycache__/retron.cpython-310.pyc differ diff --git a/SynTool/chem/__pycache__/utils.cpython-310.pyc b/SynTool/chem/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbdaff60c4fac107754d6d17e6dd8c1b12fc246e Binary files /dev/null and b/SynTool/chem/__pycache__/utils.cpython-310.pyc differ diff --git a/SynTool/chem/data/__init__.py b/SynTool/chem/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SynTool/chem/data/__pycache__/__init__.cpython-310.pyc b/SynTool/chem/data/__pycache__/__init__.cpython-310.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..d722a048daf89b7f5338f588d495aee68bb4c8be Binary files /dev/null and b/SynTool/chem/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/SynTool/chem/data/__pycache__/cleaning.cpython-310.pyc b/SynTool/chem/data/__pycache__/cleaning.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8af0739bcddd86bcae4842a6b3813273382b488 Binary files /dev/null and b/SynTool/chem/data/__pycache__/cleaning.cpython-310.pyc differ diff --git a/SynTool/chem/data/__pycache__/filtering.cpython-310.pyc b/SynTool/chem/data/__pycache__/filtering.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c36901c582fa4586c1745eb6156c37bc8d5103c Binary files /dev/null and b/SynTool/chem/data/__pycache__/filtering.cpython-310.pyc differ diff --git a/SynTool/chem/data/__pycache__/mapping.cpython-310.pyc b/SynTool/chem/data/__pycache__/mapping.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..976460e6a2ccc64043d93a7d288802dcd78f6a0e Binary files /dev/null and b/SynTool/chem/data/__pycache__/mapping.cpython-310.pyc differ diff --git a/SynTool/chem/data/__pycache__/standardizer.cpython-310.pyc b/SynTool/chem/data/__pycache__/standardizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a3f9a0bede2b466bf554d247d508192242742fd Binary files /dev/null and b/SynTool/chem/data/__pycache__/standardizer.cpython-310.pyc differ diff --git a/SynTool/chem/data/cleaning.py b/SynTool/chem/data/cleaning.py new file mode 100644 index 0000000000000000000000000000000000000000..344d9f765dbc01dde57c919234093264a2d660d0 --- /dev/null +++ b/SynTool/chem/data/cleaning.py @@ -0,0 +1,124 @@ +import os +from multiprocessing import Queue, Process, Manager, Value +from logging import getLogger, Logger +from tqdm import tqdm +from CGRtools.containers import ReactionContainer + +from .standardizer import Standardizer +from 
SynTool.utils.files import ReactionReader, ReactionWriter +from SynTool.utils.config import ReactionStandardizationConfig + + +def cleaner(reaction: ReactionContainer, logger: Logger, config: ReactionStandardizationConfig): + """ + Standardize a reaction according to external script + + :param reaction: ReactionContainer to clean/standardize + :param logger: Logger - to avoid writing log + :param config: ReactionStandardizationConfig + :return: ReactionContainer or empty list + """ + standardizer = Standardizer(id_tag='Reaction_ID', + action_on_isotopes=2, + skip_tautomerize=True, + skip_errors=config.skip_errors, + keep_unbalanced_ions=config.keep_unbalanced_ions, + keep_reagents=config.keep_reagents, + ignore_mapping=config.ignore_mapping, + logger=logger) + return standardizer.standardize(reaction) + + +def worker_cleaner(to_clean: Queue, to_write: Queue, config: ReactionStandardizationConfig): + """ + Launches standardizations using the Queue to_clean. Fills the to_write Queue with results + + :param to_clean: Queue of reactions to clean/standardize + :param to_write: Standardized outputs to write + :param config: ReactionStandardizationConfig + :return: None + """ + logger = getLogger() + logger.disabled = True + while True: + raw_reaction = to_clean.get() + if raw_reaction == "Quit": + break + res = cleaner(raw_reaction, logger, config) + to_write.put(res) + logger.disabled = False + + +def cleaner_writer(output_file: str, to_write: Queue, cleaned_nb: Value, remove_old=True): + """ + Writes in output file the standardized reactions + + :param output_file: output file path + :param to_write: Standardized ReactionContainer to write + :param cleaned_nb: number of final reactions + :param remove_old: whenever to remove or not an already existing file + """ + + if remove_old and os.path.isfile(output_file): + os.remove(output_file) + + counter = 0 + seen_reactions = [] + with ReactionWriter(output_file) as out: + while True: + res = to_write.get() + if res: + if 
res == "Quit": + cleaned_nb.set(counter) + break + elif isinstance(res, ReactionContainer): + smi = format(res, "m") + if smi not in seen_reactions: + out.write(res) + counter += 1 + seen_reactions.append(smi) + + +def reactions_cleaner(config: ReactionStandardizationConfig, + input_file: str, output_file: str, num_cpus: int, batch_prep_size: int = 100): + """ + Writes in output file the standardized reactions + + :param config: + :param input_file: input RDF file path + :param output_file: output RDF file path + :param num_cpus: number of CPU to be parallelized + :param batch_prep_size: size of each batch per CPU + """ + with Manager() as m: + to_clean = m.Queue(maxsize=num_cpus * batch_prep_size) + to_write = m.Queue(maxsize=batch_prep_size) + cleaned_nb = m.Value(int, 0) + + writer = Process(target=cleaner_writer, args=(output_file, to_write, cleaned_nb)) + writer.start() + + workers = [] + for _ in range(num_cpus - 2): + w = Process(target=worker_cleaner, args=(to_clean, to_write, config)) + w.start() + workers.append(w) + + n = 0 + with ReactionReader(input_file) as reactions: + for raw_reaction in tqdm(reactions): + if 'Reaction_ID' not in raw_reaction.meta: + raw_reaction.meta['Reaction_ID'] = n + to_clean.put(raw_reaction) + n += 1 + + for _ in workers: + to_clean.put("Quit") + for w in workers: + w.join() + + to_write.put("Quit") + writer.join() + + print(f'Initial number of reactions: {n}'), + print(f'Removed number of reactions: {n - cleaned_nb.get()}') diff --git a/SynTool/chem/data/filtering.py b/SynTool/chem/data/filtering.py new file mode 100644 index 0000000000000000000000000000000000000000..3a66fa7757e725a5ca882c0549b4278fb14041e7 --- /dev/null +++ b/SynTool/chem/data/filtering.py @@ -0,0 +1,917 @@ +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable, Tuple, Dict, Any, Optional +from tqdm.auto import tqdm + +import numpy as np +import ray +import yaml +from CGRtools.containers import 
ReactionContainer, MoleculeContainer, CGRContainer +from StructureFingerprint import MorganFingerprint + +from SynTool.utils.files import ReactionReader, ReactionWriter +from SynTool.chem.utils import remove_small_molecules, rebalance_reaction, remove_reagents +from SynTool.utils.config import ConfigABC, convert_config_to_dict + + +@dataclass +class CompeteProductsConfig(ConfigABC): + fingerprint_tanimoto_threshold: float = 0.3 + mcs_tanimoto_threshold: float = 0.6 + + @staticmethod + def from_dict(config_dict: Dict[str, Any]): + """Create an instance of CompeteProductsConfig from a dictionary.""" + return CompeteProductsConfig(**config_dict) + + @staticmethod + def from_yaml(file_path: str): + """Deserialize a YAML file into a CompeteProductsConfig object.""" + with open(file_path, "r") as file: + config_dict = yaml.safe_load(file) + return CompeteProductsConfig.from_dict(config_dict) + + def _validate_params(self, params: Dict[str, Any]): + """Validate configuration parameters.""" + if not isinstance(params.get("fingerprint_tanimoto_threshold"), float) \ + or not (0 <= params["fingerprint_tanimoto_threshold"] <= 1): + raise ValueError("Invalid 'fingerprint_tanimoto_threshold'; expected a float between 0 and 1") + + if not isinstance(params.get("mcs_tanimoto_threshold"), float) \ + or not (0 <= params["mcs_tanimoto_threshold"] <= 1): + raise ValueError("Invalid 'mcs_tanimoto_threshold'; expected a float between 0 and 1") + + +class CompeteProductsChecker: + """Checks if there are compete reactions.""" + + def __init__( + self, + fingerprint_tanimoto_threshold: float = 0.3, + mcs_tanimoto_threshold: float = 0.6, + ): + self.fingerprint_tanimoto_threshold = fingerprint_tanimoto_threshold + self.mcs_tanimoto_threshold = mcs_tanimoto_threshold + + @staticmethod + def from_config(config: CompeteProductsConfig): + """Creates an instance of CompeteProductsChecker from a configuration object.""" + return CompeteProductsChecker( + config.fingerprint_tanimoto_threshold, 
config.mcs_tanimoto_threshold + ) + + def __call__(self, reaction: ReactionContainer) -> bool: + """ + Returns True if the reaction has competing products, else False + + :param reaction: input reaction + :return: True or False + """ + mf = MorganFingerprint() + is_compete = False + + # Check for compete products using both fingerprint similarity and maximum common substructure (MCS) similarity + for mol in reaction.reagents: + for other_mol in reaction.products: + if len(mol) > 6 and len(other_mol) > 6: + # Compute fingerprint similarity + molf = mf.transform([mol]) + other_molf = mf.transform([other_mol]) + fingerprint_tanimoto = tanimoto_kernel(molf, other_molf)[0][0] + + # If fingerprint similarity is high enough, check for MCS similarity + if fingerprint_tanimoto > self.fingerprint_tanimoto_threshold: + try: + # Find the maximum common substructure (MCS) and compute its size + clique_size = len(next(mol.get_mcs_mapping(other_mol, limit=100))) + + # Calculate MCS similarity based on MCS size + mcs_tanimoto = clique_size / (len(mol) + len(other_mol) - clique_size) + + # If MCS similarity is also high enough, mark the reaction as having compete products + if mcs_tanimoto > self.mcs_tanimoto_threshold: + is_compete = True + break + except StopIteration: + continue + + return is_compete + + +@dataclass +class DynamicBondsConfig(ConfigABC): + min_bonds_number: int = 1 + max_bonds_number: int = 6 + + @staticmethod + def from_dict(config_dict: Dict[str, Any]): + """Create an instance of DynamicBondsConfig from a dictionary.""" + return DynamicBondsConfig(**config_dict) + + @staticmethod + def from_yaml(file_path: str): + """Deserialize a YAML file into a DynamicBondsConfig object.""" + with open(file_path, "r") as file: + config_dict = yaml.safe_load(file) + return DynamicBondsConfig.from_dict(config_dict) + + def _validate_params(self, params: Dict[str, Any]): + """Validate configuration parameters.""" + if not isinstance(params.get("min_bonds_number"), int) \ + or 
params["min_bonds_number"] < 0: + raise ValueError( + "Invalid 'min_bonds_number'; expected a non-negative integer") + + if not isinstance(params.get("max_bonds_number"), int) \ + or params["max_bonds_number"] < 0: + raise ValueError("Invalid 'max_bonds_number'; expected a non-negative integer") + + if params["min_bonds_number"] > params["max_bonds_number"]: + raise ValueError("'min_bonds_number' cannot be greater than 'max_bonds_number'") + + +class DynamicBondsChecker: + """Checks if there is an unacceptable number of dynamic bonds in CGR.""" + + def __init__(self, min_bonds_number: int = 1, max_bonds_number: int = 6): + self.min_bonds_number = min_bonds_number + self.max_bonds_number = max_bonds_number + + @staticmethod + def from_config(config: DynamicBondsConfig): + """Creates an instance of DynamicBondsChecker from a configuration object.""" + return DynamicBondsChecker(config.min_bonds_number, config.max_bonds_number) + + def __call__(self, reaction: ReactionContainer) -> bool: + cgr = ~reaction + return not (self.min_bonds_number <= len(cgr.center_bonds) <= self.max_bonds_number) + + +@dataclass +class SmallMoleculesConfig(ConfigABC): + limit: int = 6 + + @staticmethod + def from_dict(config_dict: Dict[str, Any]): + """Create an instance of SmallMoleculesConfig from a dictionary.""" + return SmallMoleculesConfig(**config_dict) + + @staticmethod + def from_yaml(file_path: str): + """Deserialize a YAML file into a SmallMoleculesConfig object.""" + with open(file_path, "r") as file: + config_dict = yaml.safe_load(file) + return SmallMoleculesConfig.from_dict(config_dict) + + def _validate_params(self, params: Dict[str, Any]): + """Validate configuration parameters.""" + if not isinstance(params.get("limit"), int) or params["limit"] < 1: + raise ValueError("Invalid 'limit'; expected a positive integer") + + +class SmallMoleculesChecker: + """Checks if there are only small molecules in the reaction or if there is only one small reactant or product.""" + + def 
__init__(self, limit: int = 6): + self.limit = limit + + @staticmethod + def from_config(config: SmallMoleculesConfig): + """Creates an instance of SmallMoleculesChecker from a configuration object.""" + return SmallMoleculesChecker(config.limit) + + def __call__(self, reaction: ReactionContainer) -> bool: + if (len(reaction.reactants) == 1 and self.are_only_small_molecules(reaction.reactants)) \ + or (len(reaction.products) == 1 and self.are_only_small_molecules(reaction.products)) \ + or (self.are_only_small_molecules(reaction.reactants) and self.are_only_small_molecules(reaction.products)): + return True + return False + + def are_only_small_molecules(self, molecules: Iterable[MoleculeContainer]) -> bool: + """Checks if all molecules in the given iterable are small molecules.""" + return all(len(molecule) <= self.limit for molecule in molecules) + + +@dataclass +class CGRConnectedComponentsConfig: + pass + + +class CGRConnectedComponentsChecker: + """Allows to check if CGR contains unrelated components (without reagents).""" + + @staticmethod + def from_config(config: CGRConnectedComponentsConfig): # TODO config class not used + """Creates an instance of CGRConnectedComponentsChecker from a configuration object.""" + return CGRConnectedComponentsChecker() + + def __call__(self, reaction: ReactionContainer) -> bool: + tmp_reaction = ReactionContainer(reaction.reactants, reaction.products) + cgr = ~tmp_reaction + return cgr.connected_components_count > 1 + + +@dataclass +class RingsChangeConfig: + pass + + +class RingsChangeChecker: + """Allows to check if there is changing rings number in the reaction.""" + + @staticmethod + def from_config(config: RingsChangeConfig): # TODO config class not used + """Creates an instance of RingsChecker from a configuration object.""" + return RingsChangeChecker() + + def __call__(self, reaction: ReactionContainer): + """ + Returns True if there are valence mistakes in the reaction or there is a reaction with mismatch numbers of 
all + rings or aromatic rings in reactants and products (reaction in rings) + + :param reaction: input reaction + :return: True or False + """ + + reaction.kekule() + reaction.thiele() + r_rings, r_arom_rings = self._calc_rings(reaction.reactants) + p_rings, p_arom_rings = self._calc_rings(reaction.products) + if (r_arom_rings != p_arom_rings) or (r_rings != p_rings): + return True + else: + return False + + @staticmethod + def _calc_rings(molecules: Iterable) -> Tuple[int, int]: + """ + Calculates number of all rings and number of aromatic rings in molecules + + :param molecules: set of molecules + :return: number of all rings and number of aromatic rings in molecules + """ + rings, arom_rings = 0, 0 + for mol in molecules: + rings += mol.rings_count + arom_rings += len(mol.aromatic_rings) + return rings, arom_rings + + +@dataclass +class StrangeCarbonsConfig: + # Currently empty, but can be extended in the future if needed + pass + + +class StrangeCarbonsChecker: + """Checks if there are 'strange' carbons in the reaction.""" + + @staticmethod + def from_config(config: StrangeCarbonsConfig): # TODO config class not used + """Creates an instance of StrangeCarbonsChecker from a configuration object.""" + return StrangeCarbonsChecker() + + def __call__(self, reaction: ReactionContainer) -> bool: + for molecule in reaction.reactants + reaction.products: + atoms_types = {a.atomic_symbol for _, a in molecule.atoms()} # atoms types in molecule + if len(atoms_types) == 1 and atoms_types.pop() == "C": + if len(molecule) == 1: # methane + return True + bond_types = {int(b) for _, _, b in molecule.bonds()} + if len(bond_types) == 1 and bond_types.pop() != 4: + return True # C molecules with only one type of bond (not aromatic) + return False + + +@dataclass +class NoReactionConfig: + # Currently empty, but can be extended in the future if needed + pass + + +class NoReactionChecker: + """Checks if there is no reaction in the provided reaction container.""" + + @staticmethod + 
def from_config(config: NoReactionConfig): # TODO config class not used + """Creates an instance of NoReactionChecker from a configuration object.""" + return NoReactionChecker() + + def __call__(self, reaction: ReactionContainer) -> bool: + cgr = ~reaction + return not cgr.center_atoms and not cgr.center_bonds + + +@dataclass +class MultiCenterConfig: + pass + + +class MultiCenterChecker: + """Checks if there is a multicenter reaction.""" + + @staticmethod + def from_config(config: MultiCenterConfig): # TODO config class not used + return MultiCenterChecker() + + def __call__(self, reaction: ReactionContainer) -> bool: + cgr = ~reaction + return len(cgr.centers_list) > 1 + + +@dataclass +class WrongCHBreakingConfig: + pass + + +class WrongCHBreakingChecker: + """Checks for incorrect C-C bond formation from breaking a C-H bond.""" + + @staticmethod + def from_config(config: WrongCHBreakingConfig): # TODO config class not used + return WrongCHBreakingChecker() + + def __call__(self, reaction: ReactionContainer) -> bool: + """ + Determines if a reaction involves incorrect C-C bond formation from breaking a C-H bond. + + :param reaction: The reaction to be checked. + :return: True if incorrect C-C bond formation is found, False otherwise. + """ + + reaction.kekule() + if reaction.check_valence(): + return False + reaction.thiele() + + copy_reaction = reaction.copy() + copy_reaction.explicify_hydrogens() + cgr = ~copy_reaction + reduced_cgr = cgr.augmented_substructure(cgr.center_atoms, deep=1) + + return self.is_wrong_c_h_breaking(reduced_cgr) + + @staticmethod + def is_wrong_c_h_breaking(cgr: CGRContainer) -> bool: + """ + Checks for incorrect C-C bond formation from breaking a C-H bond in a CGR. + :param cgr: The CGR with explicified hydrogens. + :return: True if incorrect C-C bond formation is found, False otherwise. 
+ """ + for atom_id in cgr.center_atoms: + if cgr.atom(atom_id).atomic_symbol == "C": + is_c_h_breaking, is_c_c_formation = False, False + c_with_h_id, another_c_id = None, None + + for neighbour_id, bond in cgr._bonds[atom_id].items(): + neighbour = cgr.atom(neighbour_id) + + if ( + bond.order + and not bond.p_order + and neighbour.atomic_symbol == "H" + ): + is_c_h_breaking = True + c_with_h_id = atom_id + + elif ( + not bond.order + and bond.p_order + and neighbour.atomic_symbol == "C" + ): + is_c_c_formation = True + another_c_id = neighbour_id + + if is_c_h_breaking and is_c_c_formation: + # Check for presence of heteroatoms in the first environment of 2 bonding carbons + if any( + cgr.atom(neighbour_id).atomic_symbol not in ("C", "H") + for neighbour_id in cgr._bonds[c_with_h_id] + ) or any( + cgr.atom(neighbour_id).atomic_symbol not in ("C", "H") + for neighbour_id in cgr._bonds[another_c_id] + ): + return False + return True + + return False + + +@dataclass +class CCsp3BreakingConfig: + pass + + +class CCsp3BreakingChecker: + """Checks if there is C(sp3)-C bond breaking.""" + + @staticmethod + def from_config(config: CCsp3BreakingConfig): # TODO config class not used + return CCsp3BreakingChecker() + + def __call__(self, reaction: ReactionContainer) -> bool: + """ + Returns True if there is C(sp3)-C bonds breaking, else False + + :param reaction: input reaction + :return: True or False + """ + cgr = ~reaction + reaction_center = cgr.augmented_substructure(cgr.center_atoms, deep=1) + for atom_id, neighbour_id, bond in reaction_center.bonds(): + atom = reaction_center.atom(atom_id) + neighbour = reaction_center.atom(neighbour_id) + + is_bond_broken = bond.order is not None and bond.p_order is None + are_atoms_carbons = ( + atom.atomic_symbol == "C" and neighbour.atomic_symbol == "C" + ) + is_atom_sp3 = atom.hybridization == 1 or neighbour.hybridization == 1 + + if is_bond_broken and are_atoms_carbons and is_atom_sp3: + return True + return False + + 
+@dataclass +class CCRingBreakingConfig: + pass + + +class CCRingBreakingChecker: + """Checks if a reaction involves ring C-C bond breaking.""" + + @staticmethod + def from_config(config: CCRingBreakingConfig): # TODO config class not used + return CCRingBreakingChecker() + + def __call__(self, reaction: ReactionContainer) -> bool: + """ + Returns True if the reaction involves ring C-C bond breaking, else False + + :param reaction: input reaction + :return: True or False + """ + cgr = ~reaction + + # Extract reactants' center atoms and their rings + reactants_center_atoms = {} + reactants_rings = set() + for reactant in reaction.reactants: + reactants_rings.update(reactant.sssr) + for n, atom in reactant.atoms(): + if n in cgr.center_atoms: + reactants_center_atoms[n] = atom + + # Identify reaction center based on center atoms + reaction_center = cgr.augmented_substructure(atoms=cgr.center_atoms, deep=0) + + # Iterate over bonds in the reaction center and check for ring C-C bond breaking + for atom_id, neighbour_id, bond in reaction_center.bonds(): + try: + # Retrieve corresponding atoms from reactants + atom = reactants_center_atoms[atom_id] + neighbour = reactants_center_atoms[neighbour_id] + except KeyError: + continue + else: + # Check if the bond is broken and both atoms are carbons in rings of size 5, 6, or 7 + is_bond_broken = (bond.order is not None) and (bond.p_order is None) + are_atoms_carbons = ( + atom.atomic_symbol == "C" and neighbour.atomic_symbol == "C" + ) + are_atoms_in_ring = ( + set(atom.ring_sizes).intersection({5, 6, 7}) + and set(neighbour.ring_sizes).intersection({5, 6, 7}) + and any( + atom_id in ring and neighbour_id in ring + for ring in reactants_rings + ) + ) + + # If all conditions are met, indicate ring C-C bond breaking + if is_bond_broken and are_atoms_carbons and are_atoms_in_ring: + return True + + return False + + +@dataclass +class ReactionCheckConfig(ConfigABC): + """ + Configuration class for reaction checks, inheriting from 
ConfigABC. + + This class manages configuration settings for various reaction checkers, including paths, file formats, + and checker-specific parameters. + + Attributes: + dynamic_bonds_config: Configuration for dynamic bonds checking. + small_molecules_config: Configuration for small molecules checking. + strange_carbons_config: Configuration for strange carbons checking. + compete_products_config: Configuration for competing products checking. + cgr_connected_components_config: Configuration for CGR connected components checking. + rings_change_config: Configuration for rings change checking. + no_reaction_config: Configuration for no reaction checking. + multi_center_config: Configuration for multi-center checking. + wrong_ch_breaking_config: Configuration for wrong C-H breaking checking. + cc_sp3_breaking_config: Configuration for CC sp3 breaking checking. + cc_ring_breaking_config: Configuration for CC ring breaking checking. + """ + + # Configuration for reaction checkers + dynamic_bonds_config: Optional[DynamicBondsConfig] = None + small_molecules_config: Optional[SmallMoleculesConfig] = None + strange_carbons_config: Optional[StrangeCarbonsConfig] = None + compete_products_config: Optional[CompeteProductsConfig] = None + cgr_connected_components_config: Optional[CGRConnectedComponentsConfig] = None + rings_change_config: Optional[RingsChangeConfig] = None + no_reaction_config: Optional[NoReactionConfig] = None + multi_center_config: Optional[MultiCenterConfig] = None + wrong_ch_breaking_config: Optional[WrongCHBreakingConfig] = None + cc_sp3_breaking_config: Optional[CCsp3BreakingConfig] = None + cc_ring_breaking_config: Optional[CCRingBreakingConfig] = None + + # Other configuration parameters + rebalance_reaction: bool = False + remove_reagents: bool = True + reagents_max_size: int = 7 + remove_small_molecules: bool = False + small_molecules_max_size: int = 6 + + def to_dict(self): + """ + Converts the configuration into a dictionary. 
+ """ + config_dict = { + "dynamic_bonds_config": convert_config_to_dict( + self.dynamic_bonds_config, DynamicBondsConfig + ), + "small_molecules_config": convert_config_to_dict( + self.small_molecules_config, SmallMoleculesConfig + ), + "compete_products_config": convert_config_to_dict( + self.compete_products_config, CompeteProductsConfig + ), + "cgr_connected_components_config": {} + if self.cgr_connected_components_config is not None + else None, + "rings_change_config": {} if self.rings_change_config is not None else None, + "strange_carbons_config": {} + if self.strange_carbons_config is not None + else None, + "no_reaction_config": {} if self.no_reaction_config is not None else None, + "multi_center_config": {} if self.multi_center_config is not None else None, + "wrong_ch_breaking_config": {} + if self.wrong_ch_breaking_config is not None + else None, + "cc_sp3_breaking_config": {} + if self.cc_sp3_breaking_config is not None + else None, + "cc_ring_breaking_config": {} + if self.cc_ring_breaking_config is not None + else None, + "rebalance_reaction": self.rebalance_reaction, + "remove_reagents": self.remove_reagents, + "reagents_max_size": self.reagents_max_size, + "remove_small_molecules": self.remove_small_molecules, + "small_molecules_max_size": self.small_molecules_max_size, + } + + filtered_config_dict = {k: v for k, v in config_dict.items() if v is not None} + + return filtered_config_dict + + @staticmethod + def from_dict(config_dict: Dict[str, Any]): + """ + Create an instance of ReactionCheckConfig from a dictionary. 
+ """ + # Instantiate configuration objects if their corresponding dictionary is present + dynamic_bonds_config = ( + DynamicBondsConfig(**config_dict["dynamic_bonds_config"]) + if "dynamic_bonds_config" in config_dict + else None + ) + small_molecules_config = ( + SmallMoleculesConfig(**config_dict["small_molecules_config"]) + if "small_molecules_config" in config_dict + else None + ) + compete_products_config = ( + CompeteProductsConfig(**config_dict["compete_products_config"]) + if "compete_products_config" in config_dict + else None + ) + cgr_connected_components_config = ( + CGRConnectedComponentsConfig() + if "cgr_connected_components_config" in config_dict + else None + ) + rings_change_config = ( + RingsChangeConfig() + if "rings_change_config" in config_dict + else None + ) + strange_carbons_config = ( + StrangeCarbonsConfig() + if "strange_carbons_config" in config_dict + else None + ) + no_reaction_config = ( + NoReactionConfig() + if "no_reaction_config" in config_dict + else None + ) + multi_center_config = ( + MultiCenterConfig() + if "multi_center_config" in config_dict + else None + ) + wrong_ch_breaking_config = ( + WrongCHBreakingConfig() + if "wrong_ch_breaking_config" in config_dict + else None + ) + cc_sp3_breaking_config = ( + CCsp3BreakingConfig() + if "cc_sp3_breaking_config" in config_dict + else None + ) + cc_ring_breaking_config = ( + CCRingBreakingConfig() + if "cc_ring_breaking_config" in config_dict + else None + ) + + # Extract other simple configuration parameters + rebalance_reaction = config_dict.get("rebalance_reaction", False) + remove_reagents = config_dict.get("remove_reagents", True) + reagents_max_size = config_dict.get("reagents_max_size", 7) + remove_small_molecules = config_dict.get("remove_small_molecules", False) + small_molecules_max_size = config_dict.get("small_molecules_max_size", 6) + + return ReactionCheckConfig( + dynamic_bonds_config=dynamic_bonds_config, + small_molecules_config=small_molecules_config, + 
compete_products_config=compete_products_config, + cgr_connected_components_config=cgr_connected_components_config, + rings_change_config=rings_change_config, + strange_carbons_config=strange_carbons_config, + no_reaction_config=no_reaction_config, + multi_center_config=multi_center_config, + wrong_ch_breaking_config=wrong_ch_breaking_config, + cc_sp3_breaking_config=cc_sp3_breaking_config, + cc_ring_breaking_config=cc_ring_breaking_config, + rebalance_reaction=rebalance_reaction, + remove_reagents=remove_reagents, + reagents_max_size=reagents_max_size, + remove_small_molecules=remove_small_molecules, + small_molecules_max_size=small_molecules_max_size, + ) + + @staticmethod + def from_yaml(file_path): + """ + Deserializes a YAML file into a ReactionCheckConfig object. + """ + with open(file_path, "r") as file: + config_dict = yaml.safe_load(file) + return ReactionCheckConfig.from_dict(config_dict) + + def _validate_params(self, params: Dict[str, Any]): + if not isinstance(params["rebalance_reaction"], bool): + raise ValueError("rebalance_reaction must be a boolean.") + + if not isinstance(params["remove_reagents"], bool): + raise ValueError("remove_reagents must be a boolean.") + + if not isinstance(params["reagents_max_size"], int): + raise ValueError("reagents_max_size must be an int.") + + if not isinstance(params["remove_small_molecules"], bool): + raise ValueError("remove_small_molecules must be a boolean.") + + if not isinstance(params["small_molecules_max_size"], int): + raise ValueError("small_molecules_max_size must be an int.") + + def create_checkers(self): + checker_instances = [] + + if self.dynamic_bonds_config is not None: + checker_instances.append( + DynamicBondsChecker.from_config(self.dynamic_bonds_config) + ) + + if self.small_molecules_config is not None: + checker_instances.append( + SmallMoleculesChecker.from_config(self.small_molecules_config) + ) + + if self.strange_carbons_config is not None: + checker_instances.append( + 
StrangeCarbonsChecker.from_config(self.strange_carbons_config) + ) + + if self.compete_products_config is not None: + checker_instances.append( + CompeteProductsChecker.from_config(self.compete_products_config) + ) + + if self.cgr_connected_components_config is not None: + checker_instances.append( + CGRConnectedComponentsChecker.from_config( + self.cgr_connected_components_config + ) + ) + + if self.rings_change_config is not None: + checker_instances.append( + RingsChangeChecker.from_config(self.rings_change_config) + ) + + if self.no_reaction_config is not None: + checker_instances.append( + NoReactionChecker.from_config(self.no_reaction_config) + ) + + if self.multi_center_config is not None: + checker_instances.append( + MultiCenterChecker.from_config(self.multi_center_config) + ) + + if self.wrong_ch_breaking_config is not None: + checker_instances.append( + WrongCHBreakingChecker.from_config(self.wrong_ch_breaking_config) + ) + + if self.cc_sp3_breaking_config is not None: + checker_instances.append( + CCsp3BreakingChecker.from_config(self.cc_sp3_breaking_config) + ) + + if self.cc_ring_breaking_config is not None: + checker_instances.append( + CCRingBreakingChecker.from_config(self.cc_ring_breaking_config) + ) + + return checker_instances + + +def tanimoto_kernel(x, y): + """ + Calculate the Tanimoto coefficient between each element of arrays x and y. 
+ """ + x = x.astype(np.float64) + y = y.astype(np.float64) + x_dot = np.dot(x, y.T) + x2 = np.sum(x**2, axis=1) + y2 = np.sum(y**2, axis=1) + + denominator = np.array([x2] * len(y2)).T + np.array([y2] * len(x2)) - x_dot + result = np.divide(x_dot, denominator, out=np.zeros_like(x_dot), where=denominator != 0) + + return result + + +def remove_file_if_exists(directory: Path, file_names): # TODO not used + for file_name in file_names: + file_path = directory / file_name + if file_path.is_file(): + file_path.unlink() + logging.warning(f"Removed {file_path}") + + +def filter_reaction(reaction: ReactionContainer, config: ReactionCheckConfig, checkers: list): + + is_filtered = False + if config.remove_small_molecules: + new_reaction = remove_small_molecules(reaction, number_of_atoms=config.small_molecules_max_size) + else: + new_reaction = reaction.copy() + + if new_reaction is None: + is_filtered = True + + if config.remove_reagents and not is_filtered: + new_reaction = remove_reagents( + new_reaction, + keep_reagents=True, + reagents_max_size=config.reagents_max_size, + ) + + if new_reaction is None: + is_filtered = True + new_reaction = reaction.copy() + # TODO you are specifying that if the reaction has only reagents, it is kept as it ? 
+ + if not is_filtered: + if config.rebalance_reaction: + new_reaction = rebalance_reaction(new_reaction) + for checker in checkers: + try: # TODO CGRTools: ValueError: mapping of graphs is not disjoint + if checker(new_reaction): + # If checker returns True it means the reaction doesn't pass the check + new_reaction.meta["filtration_log"] = checker.__class__.__name__ + is_filtered = True + except: + is_filtered = True + + + + return is_filtered, new_reaction + + +@ray.remote +def process_batch(batch, config: ReactionCheckConfig, checkers): + results = [] + for index, reaction in batch: + try: # TODO CGRtools.exceptions.MappingError: atoms with number {52} not equal + is_filtered, processed_reaction = filter_reaction(reaction, config, checkers) + results.append((index, is_filtered, processed_reaction)) + except: + results.append((index, True, reaction)) + return results + + +def process_completed_batches(futures, result_file, pbar, treated: int = 0, passed_filters: int = 0): + done, _ = ray.wait(list(futures.keys()), num_returns=1) + completed_batch = ray.get(done[0]) + + # Write results of the completed batch to file + now_treated = 0 + for index, is_filtered, reaction in completed_batch: + now_treated += 1 + if not is_filtered: + result_file.write(reaction.meta['init_smiles']) + passed_filters += 1 + + # Remove completed future and update progress bar + del futures[done[0]] + pbar.update(now_treated) + treated += now_treated + + return treated, passed_filters + + +def filter_reactions( + config: ReactionCheckConfig, + reaction_database_path: str, + result_reactions_file_name: str = "reaction_data_filtered.smi", + append_results: bool = False, + num_cpus: int = 1, + batch_size: int = 100, +) -> None: + """ + Processes a database of chemical reactions, applying checks based on the provided configuration, + and writes the results to specified files. All configurations are provided by the ReactionCheckConfig object. 
+ + :param config: ReactionCheckConfig object containing all configuration settings. + :param reaction_database_path: Path to the reaction database file. + :param result_reactions_file_name: Name for the file containing cleaned reactions. + :param append_results: Flag indicating whether to append results to existing files. + :param num_cpus: Number of CPUs to use for processing. + :param batch_size: Size of the batch for processing reactions. + :return: None. The function writes the processed reactions to specified RDF and pickle files. + Unique reactions are written if save_only_unique is True. + """ + + checkers = config.create_checkers() + + ray.init(num_cpus=num_cpus, ignore_reinit_error=True, logging_level=logging.ERROR) + max_concurrent_batches = num_cpus # Limit the number of concurrent batches + + with ReactionReader(reaction_database_path) as reactions, \ + ReactionWriter(result_reactions_file_name, append_results) as result_file: + + pbar = tqdm(reactions, leave=True) # TODO fix progress bars + + futures = {} + batch = [] + treated = filtered = 0 + for index, reaction in enumerate(reactions): + reaction.meta["reaction_index"] = index + batch.append((index, reaction)) + if len(batch) == batch_size: + future = process_batch.remote(batch, config, checkers) + futures[future] = None + batch = [] + + # Check and process completed tasks if we've reached the concurrency limit + while len(futures) >= max_concurrent_batches: + treated, filtered = process_completed_batches(futures, result_file, pbar, treated, filtered) + + # Process the last batch if it's not empty + if batch: + future = process_batch.remote(batch, config, checkers) + futures[future] = None + + # Process remaining batches + while futures: + treated, filtered = process_completed_batches(futures, result_file, pbar, treated, filtered) + + pbar.close() + + ray.shutdown() + print(f'Initial number of reactions: {treated}'), + print(f'Removed number of reactions: {treated - filtered}') diff --git 
a/SynTool/chem/data/mapping.py b/SynTool/chem/data/mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..c8a4636c464b443bb26f4310778bd6001f01e793 --- /dev/null +++ b/SynTool/chem/data/mapping.py @@ -0,0 +1,96 @@ +from pathlib import Path +from os.path import splitext +from typing import Union +from tqdm import tqdm + +from chython import smiles, RDFRead, RDFWrite, ReactionContainer +from chython.exceptions import MappingError, IncorrectSmiles + +from SynTool.utils import path_type + + +def remove_reagents_and_map(rea: ReactionContainer, keep_reagent: bool = False) -> Union[ReactionContainer, None]: + """ + Maps atoms of the reaction using chytorch. + + :param rea: reaction to map + :type rea: ReactionContainer + :param keep_reagent: whenever to remove reagent or not + :type keep_reagent: bool + + :return: ReactionContainer or None + """ + try: + rea.reset_mapping() + except MappingError: + rea.reset_mapping() # Successive reset_mapping works + if not keep_reagent: + try: + rea.remove_reagents() + except: + return None + return rea + + +def remove_reagents_and_map_from_file(input_file: path_type, output_file: path_type, keep_reagent: bool = False) -> None: + """ + Reads a file of reactions and maps atoms of the reactions using chytorch. + + :param input_file: the path and name of the input file + :type input_file: path_type + :param output_file: the path and name of the output file + :type output_file: path_type + :param keep_reagent: whenever to remove reagent or not + :type keep_reagent: bool + + :return: None + """ + input_file = str(Path(input_file).resolve(strict=True)) + _, input_ext = splitext(input_file) + if input_ext == ".smi": + input_file = open(input_file, "r") + elif input_ext == ".rdf": + input_file = RDFRead(input_file, indexable=True) + else: + raise ValueError("File extension not recognized. 
File:", input_file, + "- Please use smi or rdf file") + enumerator = input_file if input_ext == ".rdf" else input_file.readlines() + + _, out_ext = splitext(output_file) + if out_ext == ".smi": + output_file = open(output_file, "w") + elif out_ext == ".rdf": + output_file = RDFWrite(output_file) + else: + raise ValueError("File extension not recognized. File:", output_file, + "- Please use smi or rdf file") + + mapping_errors = 0 + parsing_errors = 0 + for rea_raw in tqdm(enumerator): + try: + rea = smiles(rea_raw.strip('\n')) if input_ext == ".smi" else rea_raw + except IncorrectSmiles: + parsing_errors += 1 + continue + try: + rea_mapped = remove_reagents_and_map(rea, keep_reagent) + except MappingError: + try: + rea_mapped = remove_reagents_and_map(smiles(str(rea)), keep_reagent) + except MappingError: + mapping_errors += 1 + continue + if rea_mapped: + rea_output = format(rea, "m") + "\n" if out_ext == ".smi" else rea + output_file.write(rea_output) + else: + mapping_errors += 1 + + input_file.close() + output_file.close() + + if parsing_errors: + print(parsing_errors, "reactions couldn't be parsed") + if mapping_errors: + print(mapping_errors, "reactions couldn't be mapped") diff --git a/SynTool/chem/data/mapping.py.bk b/SynTool/chem/data/mapping.py.bk new file mode 100644 index 0000000000000000000000000000000000000000..0796c137004f49a60ab6fe4569691fd51645490e --- /dev/null +++ b/SynTool/chem/data/mapping.py.bk @@ -0,0 +1,90 @@ +from pathlib import Path +from os.path import splitext +from typing import Union +from tqdm import tqdm + +from chython import smiles, RDFRead, RDFWrite, ReactionContainer +from chython.exceptions import MappingError + +from Syntool.utils import path_type + + +def remove_reagents_and_map(rea: ReactionContainer) -> Union[ReactionContainer, None]: + """ + Maps atoms of the reaction using chytorch. 
+ + :param rea: reaction to map + :type rea: ReactionContainer + + :return: ReactionContainer or None + """ + try: + rea.reset_mapping() + except MappingError: + rea.reset_mapping() + try: + rea.remove_reagents() + return rea + except: + # print("Error", str(rea)) + return None + + +def remove_reagents_and_map_from_file(input_file: path_type, output_file: path_type) -> None: + """ + Reads a file of reactions and maps atoms of the reactions using chytorch. + + :param input_file: the path and name of the input file + :type input_file: path_type + + :param output_file: the path and name of the output file + :type output_file: path_type + + :return: None + """ + input_file = str(Path(input_file).resolve(strict=True)) + _, input_ext = splitext(input_file) + if input_ext == ".smi": + input_file = open(input_file, "r") + elif input_ext == ".rdf": + input_file = RDFRead(input_file, indexable=True) + else: + raise ValueError("File extension not recognized. File:", input_file, + "- Please use smi or rdf file") + enumerator = input_file if input_ext == ".rdf" else input_file.readlines() + + _, out_ext = splitext(output_file) + if out_ext == ".smi": + output_file = open(output_file, "w") + elif out_ext == ".rdf": + output_file = RDFWrite(output_file) + else: + raise ValueError("File extension not recognized. 
File:", output_file, + "- Please use smi or rdf file") + + mapping_errors = 0 + parsing_errors = 0 + for rea_raw in tqdm(enumerator): + try: + rea = smiles(rea_raw.strip('\n')) if input_ext == ".smi" else rea_raw + except: + parsing_errors += 1 + print("Error", parsing_errors, rea_raw) + continue + try: + rea_mapped = remove_reagents_and_map(rea) + except: + parsing_errors += 1 + print("Error for,", rea) + continue + if rea_mapped: + rea_output = format(rea, "m") + "\n" if out_ext == ".smi" else rea + output_file.write(rea_output) + else: + mapping_errors += 1 + + input_file.close() + output_file.close() + + if mapping_errors: + print(mapping_errors, "reactions couldn't be mapped") diff --git a/SynTool/chem/data/standardizer.py b/SynTool/chem/data/standardizer.py new file mode 100644 index 0000000000000000000000000000000000000000..6ac41e1914570449db0b3769e9c7de9c4e6e3b96 --- /dev/null +++ b/SynTool/chem/data/standardizer.py @@ -0,0 +1,604 @@ +############################################################################# +# Code issued from https://github.com/Laboratoire-de-Chemoinformatique/Reaction_Data_Cleaning +# Reaction_Data_Cleaning/scripts/standardizer.py +# version as it from commit 793475e54d8b2c7f714165a61e4eb439435d7d92 +# DOI 10.1002/minf.202100119 +############################################################################# +# Chemical reactions data curation best practices +# including optimized RDTool +############################################################################# +# GNU LGPL https://www.gnu.org/licenses/lgpl-3.0.en.html +############################################################################# +# Corresponding Authors: Timur Madzhidov and Alexandre Varnek +# Corresponding Authors' emails: tmadzhidov@gmail.com and varnek@unistra.fr +# Main contributors: Arkadii Lin, Natalia Duybankova, Ramil Nugmanov, Rail Suleymanov and Timur Madzhidov +# Copyright: Copyright 2020, +# MaDeSmart, Machine Design of Small Molecules by AI +# VLAIO 
project HBC.2018.2287 +# Credits: Kazan Federal University, Russia +# University of Strasbourg, France +# University of Linz, Austria +# University of Leuven, Belgium +# Janssen Pharmaceutica N.V., Beerse, Belgium +# Rail Suleymanov, Arcadia, St. Petersburg, Russia +# License: GNU LGPL https://www.gnu.org/licenses/lgpl-3.0.en.html +# Version: 00.02 +############################################################################# + +from CGRtools.files import RDFRead, RDFWrite, SDFWrite, SDFRead, SMILESRead +from CGRtools.containers import MoleculeContainer, ReactionContainer +import logging +from ordered_set import OrderedSet +import os +import io +import pathlib +from pathlib import PurePosixPath + + +class Standardizer: + def __init__(self, skip_errors=False, log_file=None, keep_unbalanced_ions=False, id_tag='Reaction_ID', + action_on_isotopes=0, keep_reagents=False, logger=None, ignore_mapping=False, jvm_path=None, + rdkit_dearomatization=False, remove_unchanged_parts=True, skip_tautomerize=True, + jchem_path=None, add_reagents_to_reactants=False) -> None: + if logger is None: + self.logger = self._config_log(log_file, logger_name='logger') + else: + self.logger = logger + self._skip_errors = skip_errors + self._keep_unbalanced_ions = keep_unbalanced_ions + self._id_tag = id_tag + self._action_on_isotopes = action_on_isotopes + self._keep_reagents = keep_reagents + self._ignore_mapping = ignore_mapping + self._remove_unchanged_parts_flag = remove_unchanged_parts + self._skip_tautomerize = skip_tautomerize + self._dearomatize_by_rdkit = rdkit_dearomatization + self._reagents_to_reactants = add_reagents_to_reactants + if not skip_tautomerize: + if jvm_path: + os.environ['JDK_HOME'] = jvm_path + os.environ['JAVA_HOME'] = jvm_path + os.environ['PATH'] += f';{PurePosixPath(jvm_path).joinpath("bin").joinpath("server")};' \ + f'{PurePosixPath(jvm_path).joinpath("bin").joinpath("server")};' + if jchem_path: + import jnius_config + jnius_config.add_classpath(jchem_path) + 
from jnius import autoclass + Standardizer = autoclass('chemaxon.standardizer.Standardizer') + self._Molecule = autoclass('chemaxon.struc.Molecule') + self._MolHandler = autoclass('chemaxon.util.MolHandler') + self._standardizer = Standardizer('tautomerize') + + def standardize_file(self, input_file=None) -> OrderedSet: + """ + Standardize a set of reactions in a file. Returns an ordered set of ReactionContainer objects passed the + standardization protocol. + :param input_file: str + :return: OrderedSet + """ + if pathlib.Path(input_file).suffix == '.rdf': + data = self._read_RDF(input_file) + elif pathlib.Path(input_file).suffix == '.smi' or pathlib.Path(input_file).suffix == '.smiles': + data = self._read_SMILES(input_file) + else: + raise ValueError('Data format is not recognized!') + + print("{0} reactions passed..".format(len(data))) + return data + + def _read_RDF(self, input_file) -> OrderedSet: + """ + Reads an RDF file. Returns an ordered set of ReactionContainer objects passed the standardization protocol. 
+ :param input_file: str + :return: OrderedSet + """ + data = OrderedSet() + self.logger.info('Start..') + with RDFRead(input_file, ignore=self._ignore_mapping, store_log=True, remap=self._ignore_mapping) as ifile, \ + open(input_file) as meta_searcher: + for reaction in ifile._data: + if isinstance(reaction, tuple): + meta_searcher.seek(reaction.position) + flag = False + for line in meta_searcher: + if flag and '$RFMT' in line: + self.logger.critical(f'Reaction id extraction problem rised for the reaction ' + f'#{reaction.number + 1}: a reaction id was expected but $RFMT line ' + f'was found!') + if flag: + self.logger.critical(f'Reaction {line.strip().split()[1]}: Parser has returned an error ' + f'message\n{reaction.log}') + break + elif '$RFMT' in line: + self.logger.critical(f'Reaction #{reaction.number + 1} has no reaction id!') + elif f'$DTYPE {self._id_tag}' in line: + flag = True + continue + standardized_reaction = self.standardize(reaction) + if standardized_reaction: + if standardized_reaction not in data: + data.add(standardized_reaction) + else: + i = data.index(standardized_reaction) + if 'Extraction_IDs' not in data[i].meta: + data[i].meta['Extraction_IDs'] = '' + data[i].meta['Extraction_IDs'] = ','.join(data[i].meta['Extraction_IDs'].split(',') + + [reaction.meta[self._id_tag]]) + self.logger.info('Reaction {0} is a duplicate of the reaction {1}..' + .format(reaction.meta[self._id_tag], data[i].meta[self._id_tag])) + return data + + def _read_SMILES(self, input_file) -> OrderedSet: + """ + Reads a SMILES file. Returns an ordered set of ReactionContainer objects passed the standardization protocol. 
+ :param input_file: str + :return: OrderedSet + """ + data = OrderedSet() + self.logger.info('Start..') + with SMILESRead(input_file, ignore=True, store_log=True, remap=self._ignore_mapping, header=True) as ifile, \ + open(input_file) as meta_searcher: + id_tag_position = meta_searcher.readline().strip().split().index(self._id_tag) + if id_tag_position is None or id_tag_position == 0: + self.logger.critical(f'No reaction ID tag was found in the header!') + raise ValueError(f'No reaction ID tag was found in the header!') + for reaction in ifile._data: + if isinstance(reaction, tuple): + meta_searcher.seek(reaction.position) + line = meta_searcher.readline().strip().split() + if len(line) <= id_tag_position: + self.logger.critical(f'No reaction ID tag was found in line {reaction.number}!') + raise ValueError(f'No reaction ID tag was found in line {reaction.number}!') + r_id = line[id_tag_position] + self.logger.critical(f'Reaction {r_id}: Parser has returned an error message\n{reaction.log}') + continue + + standardized_reaction = self.standardize(reaction) + if standardized_reaction: + if standardized_reaction not in data: + data.add(standardized_reaction) + else: + i = data.index(standardized_reaction) + if 'Extraction_IDs' not in data[i].meta: + data[i].meta['Extraction_IDs'] = '' + data[i].meta['Extraction_IDs'] = ','.join(data[i].meta['Extraction_IDs'].split(',') + + [reaction.meta[self._id_tag]]) + self.logger.info('Reaction {0} is a duplicate of the reaction {1}..' 
+ .format(reaction.meta[self._id_tag], data[i].meta[self._id_tag])) + return data + + def standardize(self, reaction: ReactionContainer) -> ReactionContainer: + """ + Standardization protocol: transform functional groups, kekulize, remove explicit hydrogens, + check for radicals (remove if something was found), check for isotopes, regroup ions (if the total charge + of reactants and/or products is not zero, and the 'keep_unbalanced_ions' option is False which is by default, + such reactions are removed; if the 'keep_unbalanced_ions' option is set True, they are kept), check valences + (remove if something is wrong), aromatize (thiele method), fix mapping (for symmetric functional groups) if + such is in, remove unchanged parts. + :param reaction: ReactionContainer + :return: ReactionContainer + """ + self.logger.info('Reaction {0}..'.format(reaction.meta[self._id_tag])) + try: + reaction.standardize() + except: + self.logger.exception( + 'Reaction {0}: Cannot standardize functional groups..'.format(reaction.meta[self._id_tag])) + if not self._skip_errors: + raise Exception( + 'Reaction {0}: Cannot standardize functional groups..'.format(reaction.meta[self._id_tag])) + else: + return + try: + reaction.kekule() + except: + self.logger.exception('Reaction {0}: Cannot kekulize..'.format(reaction.meta[self._id_tag])) + if not self._skip_errors: + raise Exception('Reaction {0}: Cannot kekulize..'.format(reaction.meta[self._id_tag])) + else: + return + try: + if self._check_valence(reaction): + self.logger.info( + 'Reaction {0}: Bad valence: {1}'.format(reaction.meta[self._id_tag], reaction.meta['mistake'])) + return + except: + self.logger.exception('Reaction {0}: Cannot check valence..'.format(reaction.meta[self._id_tag])) + if not self._skip_errors: + self.logger.critical('Stop the algorithm!') + raise Exception('Reaction {0}: Cannot check valence..'.format(reaction.meta[self._id_tag])) + else: + return + try: + if not self._skip_tautomerize: + reaction = 
self._tautomerize(reaction) + except: + self.logger.exception('Reaction {0}: Cannot tautomerize..'.format(reaction.meta[self._id_tag])) + if not self._skip_errors: + raise Exception('Reaction {0}: Cannot tautomerize..'.format(reaction.meta[self._id_tag])) + else: + return + try: + reaction.implicify_hydrogens() + except: + self.logger.exception( + 'Reaction {0}: Cannot remove explicit hydrogens..'.format(reaction.meta[self._id_tag])) + if not self._skip_errors: + raise Exception('Reaction {0}: Cannot remove explicit hydrogens..'.format(reaction.meta[self._id_tag])) + else: + return + try: + if self._check_radicals(reaction): + self.logger.info('Reaction {0}: Radicals were found..'.format(reaction.meta[self._id_tag])) + return + except: + self.logger.exception('Reaction {0}: Cannot check radicals..'.format(reaction.meta[self._id_tag])) + if not self._skip_errors: + raise Exception('Reaction {0}: Cannot check radicals..'.format(reaction.meta[self._id_tag])) + else: + return + try: + if self._action_on_isotopes == 1 and self._check_isotopes(reaction): + self.logger.info('Reaction {0}: Isotopes were found..'.format(reaction.meta[self._id_tag])) + return + elif self._action_on_isotopes == 2 and self._check_isotopes(reaction): + reaction.clean_isotopes() + self.logger.info('Reaction {0}: Isotopes were removed but the reaction was kept..'.format( + reaction.meta[self._id_tag])) + except: + self.logger.exception('Reaction {0}: Cannot check for isotopes..'.format(reaction.meta[self._id_tag])) + if not self._skip_errors: + raise Exception('Reaction {0}: Cannot check for isotopes..'.format(reaction.meta[self._id_tag])) + else: + return + try: + reaction, return_code = self._split_ions(reaction) + if return_code == 1: + self.logger.info('Reaction {0}: Ions were split..'.format(reaction.meta[self._id_tag])) + elif return_code == 2: + self.logger.info('Reaction {0}: Ions were split but the reaction is imbalanced..'.format( + reaction.meta[self._id_tag])) + if not 
self._keep_unbalanced_ions: + return + except: + self.logger.exception('Reaction {0}: Cannot group ions..'.format(reaction.meta[self._id_tag])) + if not self._skip_errors: + raise Exception('Reaction {0}: Cannot group ions..'.format(reaction.meta[self._id_tag])) + else: + return + try: + reaction.thiele() + except: + self.logger.exception('Reaction {0}: Cannot aromatize..'.format(reaction.meta[self._id_tag])) + if not self._skip_errors: + raise Exception('Reaction {0}: Cannot aromatize..'.format(reaction.meta[self._id_tag])) + else: + return + try: + reaction.fix_mapping() + except: + self.logger.exception('Reaction {0}: Cannot fix mapping..'.format(reaction.meta[self._id_tag])) + if not self._skip_errors: + raise Exception('Reaction {0}: Cannot fix mapping..'.format(reaction.meta[self._id_tag])) + else: + return + try: + if self._remove_unchanged_parts_flag: + reaction = self._remove_unchanged_parts(reaction) + if not reaction.reactants and reaction.products: + self.logger.info('Reaction {0}: Reactants are empty..'.format(reaction.meta[self._id_tag])) + return + if not reaction.products and reaction.reactants: + self.logger.info('Reaction {0}: Products are empty..'.format(reaction.meta[self._id_tag])) + return + if not reaction.reactants and not reaction.products: + self.logger.exception( + 'Reaction {0}: Cannot remove unchanged parts or the reaction is empty..'.format( + reaction.meta[self._id_tag])) + return + except: + self.logger.exception('Reaction {0}: Cannot remove unchanged parts or the reaction is empty..'.format( + reaction.meta[self._id_tag])) + if not self._skip_errors: + raise Exception('Reaction {0}: Cannot remove unchanged parts or the reaction is empty..'.format( + reaction.meta[self._id_tag])) + else: + return + self.logger.debug('Reaction {0} is done..'.format(reaction.meta[self._id_tag])) + return reaction + + def write(self, output_file: str, data: OrderedSet) -> None: + """ + Dump a set of reactions. 
+ :param data: OrderedSet + :param output_file: str + :return: None + """ + with RDFWrite(output_file) as out: + for r in data: + out.write(r) + + def _check_valence(self, reaction: ReactionContainer) -> bool: + """ + Checks valences. + :param reaction: ReactionContainer + :return: bool + """ + mistakes = [] + for molecule in (reaction.reactants + reaction.products + reaction.reagents): + valence_mistakes = molecule.check_valence() + if valence_mistakes: + mistakes.append(("|".join([str(num) for num in valence_mistakes]), + "|".join([str(molecule.atom(n)) for n in valence_mistakes]), str(molecule))) + if mistakes: + message = ",".join([f'{atom_nums} at {atoms} in {smiles}' for atom_nums, atoms, smiles in mistakes]) + reaction.meta['mistake'] = f'Valence mistake: {message}' + return True + return False + + def _config_log(self, log_file: str, logger_name: str): + logger = logging.getLogger(logger_name) + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter(fmt='%(asctime)s: %(message)s', datefmt='%d/%m/%Y %H:%M:%S') + logger.handlers.clear() + fileHandler = logging.FileHandler(filename=log_file, mode='w') + fileHandler.setFormatter(formatter) + fileHandler.setLevel(logging.DEBUG) + logger.addHandler(fileHandler) + # logging.basicConfig(filename=log_file, level=logging.info, filemode='w', format='%(asctime)s: %(message)s', + # datefmt='%d/%m/%Y %H:%M:%S') + return logger + + def _check_radicals(self, reaction: ReactionContainer) -> bool: + """ + Checks radicals. + :param reaction: ReactionContainer + :return: bool + """ + for molecule in (reaction.reactants + reaction.products + reaction.reagents): + for n, atom in molecule.atoms(): + if atom.is_radical: + return True + return False + + def _calc_charge(self, molecule: MoleculeContainer) -> int: + """Computing charge of molecule. 
+ :param: molecule: MoleculeContainer + :return: int + """ + return sum(molecule._charges.values()) + + def _group_ions(self, reaction: ReactionContainer): + """ + Ungroup molecules recorded as ions, regroup ions. Returns a tuple with the corresponding ReactionContainer and + return code as int (0 - nothing was changed, 1 - ions were regrouped, 2 - ions are unbalanced). + :param reaction: current reaction + :return: tuple[ReactionContainer, int] + """ + meta = reaction.meta + reaction_parts = [] + return_codes = [] + for molecules in (reaction.reactants, reaction.reagents, reaction.products): + divided_molecules = [x for m in molecules for x in m.split('.')] + + if len(divided_molecules) == 0: + reaction_parts.append(()) + continue + elif len(divided_molecules) == 1 and self._calc_charge(divided_molecules[0]) == 0: + return_codes.append(0) + reaction_parts.append(molecules) + continue + elif len(divided_molecules) == 1: + return_codes.append(2) + reaction_parts.append(molecules) + continue + + new_molecules = [] + cations, anions, ions = [], [], [] + total_charge = 0 + for molecule in divided_molecules: + mol_charge = self._calc_charge(molecule) + total_charge += mol_charge + if mol_charge == 0: + new_molecules.append(molecule) + elif mol_charge > 0: + cations.append((mol_charge, molecule)) + ions.append((mol_charge, molecule)) + else: + anions.append((mol_charge, molecule)) + ions.append((mol_charge, molecule)) + + if len(cations) == 0 and len(anions) == 0: + return_codes.append(0) + reaction_parts.append(tuple(new_molecules)) + continue + elif total_charge != 0: + return_codes.append(2) + reaction_parts.append(tuple(divided_molecules)) + continue + else: + salt = MoleculeContainer() + for ion_charge, ion in ions: + salt = salt.union(ion) + total_charge += ion_charge + if total_charge == 0: + new_molecules.append(salt) + salt = MoleculeContainer() + if total_charge != 0: + new_molecules.append(salt) + return_codes.append(2) + 
reaction_parts.append(tuple(new_molecules)) + else: + return_codes.append(1) + reaction_parts.append(tuple(new_molecules)) + return ReactionContainer(reactants=reaction_parts[0], reagents=reaction_parts[1], products=reaction_parts[2], + meta=meta), max(return_codes) + + def _split_ions(self, reaction: ReactionContainer): + """ + Split ions in a reaction. Returns a tuple with the corresponding ReactionContainer and + a return code as int (0 - nothing was changed, 1 - ions were split, 2 - ions were split but the reaction + is imbalanced). + :param reaction: current reaction + :return: tuple[ReactionContainer, int] + """ + meta = reaction.meta + reaction_parts = [] + return_codes = [] + for molecules in (reaction.reactants, reaction.reagents, reaction.products): + divided_molecules = [x for m in molecules for x in m.split('.')] + + total_charge = 0 + ions_present = False + for molecule in divided_molecules: + mol_charge = self._calc_charge(molecule) + total_charge += mol_charge + if mol_charge != 0: + ions_present = True + + if ions_present and total_charge: + return_codes.append(2) + elif ions_present: + return_codes.append(1) + else: + return_codes.append(0) + + reaction_parts.append(tuple(divided_molecules)) + + return ReactionContainer(reactants=reaction_parts[0], reagents=reaction_parts[1], products=reaction_parts[2], + meta=meta), max(return_codes) + + def _remove_unchanged_parts(self, reaction: ReactionContainer) -> ReactionContainer: + """ + Ungroup molecules, remove unchanged parts from reactants and products. 
+ :param reaction: current reaction + :return: ReactionContainer + """ + meta = reaction.meta + new_reactants = [m for m in reaction.reactants] + new_reagents = [m for m in reaction.reagents] + if self._reagents_to_reactants: + new_reactants.extend(new_reagents) + new_reagents = [] + reactants = new_reactants.copy() + new_products = [m for m in reaction.products] + + for reactant in reactants: + if reactant in new_products: + new_reagents.append(reactant) + new_reactants.remove(reactant) + new_products.remove(reactant) + if not self._keep_reagents: + new_reagents = [] + return ReactionContainer(reactants=tuple(new_reactants), reagents=tuple(new_reagents), + products=tuple(new_products), meta=meta) + + def _check_isotopes(self, reaction: ReactionContainer) -> bool: + for molecules in (reaction.reactants, reaction.products): + for molecule in molecules: + for _, atom in molecule.atoms(): + if atom.isotope: + return True + return False + + def _tautomerize(self, reaction: ReactionContainer) -> ReactionContainer: + """ + Perform ChemAxon tautomerization. 
+ :param reaction: reaction that needs to be tautomerized + :return: ReactionContainer + """ + new_molecules = [] + for part in [reaction.reactants, reaction.reagents, reaction.products]: + tmp = [] + for mol in part: + with io.StringIO() as f, SDFWrite(f) as i: + i.write(mol) + sdf = f.getvalue() + mol_handler = self._MolHandler(sdf) + mol_handler.clean(True, '2') + molecule = mol_handler.getMolecule() + self._standardizer.standardize(molecule) + new_mol_handler = self._MolHandler(molecule) + new_sdf = new_mol_handler.toFormat('SDF') + with io.StringIO('\n ' + new_sdf.strip()) as f, SDFRead(f, remap=False) as i: + new_mol = next(i) + tmp.append(new_mol) + new_molecules.append(tmp) + return ReactionContainer(reactants=tuple(new_molecules[0]), reagents=tuple(new_molecules[1]), + products=tuple(new_molecules[2]), meta=reaction.meta) + + # def _dearomatize_by_RDKit(self, reaction: ReactionContainer) -> ReactionContainer: + # """ + # Dearomatizes by RDKit (needs in case of some mappers, such as RXNMapper). 
+ # :param reaction: ReactionContainer + # :return: ReactionContainer + # """ + # with io.StringIO() as f, RDFWrite(f) as i: + # i.write(reaction) + # s = '\n'.join(f.getvalue().split('\n')[3:]) + # rxn = rdChemReactions.ReactionFromRxnBlock(s) + # reactants, reagents, products = [], [], [] + # for mol in rxn.GetReactants(): + # try: + # Chem.SanitizeMol(mol, Chem.SanitizeFlags.SANITIZE_KEKULIZE, catchErrors=True) + # except Chem.rdchem.KekulizeException: + # return reaction + # with io.StringIO(Chem.MolToMolBlock(mol)) as f2, SDFRead(f2, remap=False) as sdf_i: + # reactants.append(next(sdf_i)) + # for mol in rxn.GetAgents(): + # try: + # Chem.SanitizeMol(mol, Chem.SanitizeFlags.SANITIZE_KEKULIZE, catchErrors=True) + # except Chem.rdchem.KekulizeException: + # return reaction + # with io.StringIO(Chem.MolToMolBlock(mol)) as f2, SDFRead(f2, remap=False) as sdf_i: + # reagents.append(next(sdf_i)) + # for mol in rxn.GetProducts(): + # try: + # Chem.SanitizeMol(mol, Chem.SanitizeFlags.SANITIZE_KEKULIZE, catchErrors=True) + # except Chem.rdchem.KekulizeException: + # return reaction + # with io.StringIO(Chem.MolToMolBlock(mol)) as f2, SDFRead(f2, remap=False) as sdf_i: + # products.append(next(sdf_i)) + # + # new_reaction = ReactionContainer(reactants=tuple(reactants), reagents=tuple(reagents), products=tuple(products), + # meta=reaction.meta) + # + # return new_reaction + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser(description="This is a tool for reaction standardization.", + epilog="Arkadii Lin, Strasbourg/Kazan 2020", prog="Standardizer") + parser.add_argument("-i", "--input", type=str, help="Input RDF file.") + parser.add_argument("-o", "--output", type=str, help="Output RDF file.") + parser.add_argument("-id", "--idTag", default='Reaction_ID', type=str, help="ID tag in the RDF file.") + parser.add_argument("--skipErrors", action="store_true", help="Skip errors.") + parser.add_argument("--keep_unbalanced_ions", 
action="store_true", help="Will keep reactions with unbalanced ions.") + parser.add_argument("--action_on_isotopes", type=int, default=0, help="Action performed if an isotope is " + "found: 0 - to ignore isotopes; " + "1 - to remove reactions with isotopes; " + "2 - to clear isotopes' labels.") + parser.add_argument("--keep_reagents", action="store_true", help="Will keep reagents from the reaction.") + parser.add_argument("--add_reagents", action="store_true", help="Will add the given reagents to reactants.") + parser.add_argument("--ignore_mapping", action="store_true", help="Will ignore the initial mapping in the file.") + parser.add_argument("--keep_unchanged_parts", action="store_true", help="Will keep unchanged parts in a reaction.") + parser.add_argument("--logFile", type=str, default='logFile.txt', help="Log file name.") + parser.add_argument("--skip_tautomerize", action="store_true", help="Will skip generation of the major tautomer.") + parser.add_argument("--rdkit_dearomatization", action="store_true", help="Will kekulize the reaction using RDKit " + "facilities.") + parser.add_argument("--jvm_path", type=str, + help="JVM path (e.g. C:\\Program Files\\Java\\jdk-13.0.2).") + parser.add_argument("--jchem_path", type=str, help="JChem path (e.g. 
C:\\Users\\user\\JChemSuite\\lib\\jchem.jar).") + args = parser.parse_args() + + standardizer = Standardizer(skip_errors=args.skipErrors, log_file=args.logFile, + keep_unbalanced_ions=args.keep_unbalanced_ions, id_tag=args.idTag, + action_on_isotopes=args.action_on_isotopes, keep_reagents=args.keep_reagents, + ignore_mapping=args.ignore_mapping, skip_tautomerize=args.skip_tautomerize, + remove_unchanged_parts=(not args.keep_unchanged_parts), jvm_path=args.jvm_path, + jchem_path=args.jchem_path, rdkit_dearomatization=args.rdkit_dearomatization, + add_reagents_to_reactants=args.add_reagents) + data = standardizer.standardize_file(input_file=args.input) + standardizer.write(output_file=args.output, data=data) diff --git a/SynTool/chem/reaction.py b/SynTool/chem/reaction.py new file mode 100755 index 0000000000000000000000000000000000000000..77f789ff880b4181c2aa12684c98ed7856b3ac67 --- /dev/null +++ b/SynTool/chem/reaction.py @@ -0,0 +1,107 @@ +""" +Module containing classes and functions for manipulating reactions and reaction rules +""" + +from CGRtools.reactor import Reactor +from CGRtools.containers import MoleculeContainer, ReactionContainer +from CGRtools.exceptions import InvalidAromaticRing + + +class Reaction(ReactionContainer): + """ + Reaction class can be used for a general representation of reaction for different chemoinformatics Python packages + """ + + def __init__(self, *args, **kwargs): + """ + Initializes the reaction object. + """ + super().__init__(*args, **kwargs) + + +def add_small_mols(big_mol, small_molecules=None): + """ + The function takes a molecule and returns a list of modified molecules where each small molecule has been added to + the big molecule. + + :param big_mol: A molecule + :param small_molecules: A list of small molecules that need to be added to the molecule + :return: Returns a list of molecules. 
+ """ + if small_molecules: + tmp_mol = big_mol.copy() + transition_mapping = {} + for small_mol in small_molecules: + + for n, atom in small_mol.atoms(): + new_number = tmp_mol.add_atom(atom.atomic_symbol) + transition_mapping[n] = new_number + + for atom, neighbor, bond in small_mol.bonds(): + tmp_mol.add_bond(transition_mapping[atom], transition_mapping[neighbor], bond) + + transition_mapping = {} + return tmp_mol.split() + else: + return [big_mol] + + +def apply_reaction_rule( + molecule: MoleculeContainer, + reaction_rule: Reactor, + sort_reactions: bool = False, + top_reactions_num: int = 3, + validate_products: bool = True, + rebuild_with_cgr: bool = False, +) -> list[MoleculeContainer]: + """ + The function applies a reaction rule to a given molecule. + + :param rebuild_with_cgr: + :param validate_products: + :param sort_reactions: + :param top_reactions_num: + :param molecule: A MoleculeContainer object representing the molecule on which the reaction rule will be applied + :type molecule: MoleculeContainer + :param reaction_rule: The reaction_rule is an instance of the Reactor class. 
It represents a reaction rule that + can be applied to a molecule + :type reaction_rule: Reactor + """ + + reactants = add_small_mols(molecule, small_molecules=False) + + try: + if sort_reactions: + unsorted_reactions = list(reaction_rule(reactants)) + sorted_reactions = sorted( + unsorted_reactions, + key=lambda react: len(list(filter(lambda mol: len(mol) > 6, react.products))), + reverse=True + ) + reactions = sorted_reactions[:top_reactions_num] # Take top-N reactions from reactor + else: + reactions = [] + for reaction in reaction_rule(reactants): + reactions.append(reaction) + if len(reactions) == top_reactions_num: + break + except IndexError: + reactions = [] + + for reaction in reactions: + if rebuild_with_cgr: + cgr = reaction.compose() + products = cgr.decompose()[1].split() + else: + products = reaction.products + products = [mol for mol in products if len(mol) > 0] + if validate_products: + for molecule in products: + try: + molecule.kekule() + if molecule.check_valence(): + yield None + molecule.thiele() + except InvalidAromaticRing: + yield None + yield products diff --git a/SynTool/chem/reaction_rules/__init__.py b/SynTool/chem/reaction_rules/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SynTool/chem/reaction_rules/__pycache__/__init__.cpython-310.pyc b/SynTool/chem/reaction_rules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..979dbb4bef420d97b18a5a953c1cd4a5a645f1cf Binary files /dev/null and b/SynTool/chem/reaction_rules/__pycache__/__init__.cpython-310.pyc differ diff --git a/SynTool/chem/reaction_rules/__pycache__/extraction.cpython-310.pyc b/SynTool/chem/reaction_rules/__pycache__/extraction.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97e635f0947372978b3298bd6f3f334536a9876b Binary files /dev/null and 
"""
Module containing functions with fixed protocol for reaction rules extraction
"""
import logging
import pickle
from collections import defaultdict
from itertools import islice
from os.path import splitext
from pathlib import Path
from typing import List, Union, Tuple, IO, Dict, Set, Iterable, Any

import ray
from CGRtools.containers import MoleculeContainer, QueryContainer, ReactionContainer
from CGRtools.exceptions import InvalidAromaticRing
from CGRtools.reactor import Reactor
from tqdm.auto import tqdm

from SynTool.chem.utils import reverse_reaction
from SynTool.utils.config import RuleExtractionConfig
from SynTool.utils.files import ReactionReader


def extract_rules_from_reactions(
    config: RuleExtractionConfig,
    reaction_file: str,
    rules_file_name: str = 'reaction_rules.pickle',
    num_cpus: int = 1,
    batch_size: int = 10,
) -> None:
    """
    Extracts reaction rules from a reaction database based on the given configuration.

    Reactions are processed in batches distributed over Ray workers; extracted rules
    are aggregated together with the indices of the reactions they came from, sorted
    by popularity and pickled to ``rules_file_name``.

    Fix over the previous revision: ``remaining_size`` is initialized before the
    drain loop, so the function no longer raises NameError when the number of
    reactions is an exact multiple of ``batch_size`` (final partial batch empty).

    :param config: configuration settings for rule extraction
    :param reaction_file: path to the file containing the reaction database
    :param rules_file_name: name of the pickle file to store the extracted rules
    :param num_cpus: number of CPU cores (and concurrent in-flight batches) to use
    :param batch_size: number of reactions per batch
    :return: None
    """
    reaction_file = Path(reaction_file).resolve(strict=True)

    ray.init(num_cpus=num_cpus, ignore_reinit_error=True, logging_level=logging.ERROR)

    rules_file_name, _ = splitext(rules_file_name)
    with ReactionReader(reaction_file) as reactions:
        pbar = tqdm(reactions, disable=False)

        futures = {}
        batch = []
        max_concurrent_batches = num_cpus
        # fix: defined up-front so the drain loop below cannot hit a NameError
        # when the last batch happens to be empty
        remaining_size = batch_size

        extracted_rules_and_statistics = defaultdict(list)
        for index, reaction in enumerate(reactions):
            batch.append((index, reaction))
            if len(batch) == batch_size:
                future = process_reaction_batch.remote(batch, config)
                futures[future] = None
                batch = []

            # throttle: never keep more in-flight batches than workers
            while len(futures) >= max_concurrent_batches:
                process_completed_batches(futures, extracted_rules_and_statistics, pbar, batch_size)

        if batch:
            remaining_size = len(batch)
            future = process_reaction_batch.remote(batch, config)
            futures[future] = None

        while futures:
            process_completed_batches(futures, extracted_rules_and_statistics, pbar, remaining_size)

        pbar.close()

    sorted_rules = sort_rules(
        extracted_rules_and_statistics,
        min_popularity=config.min_popularity,
        single_reactant_only=config.single_reactant_only,
    )

    with open(f"{rules_file_name}.pickle", "wb") as statistics_file:
        pickle.dump(sorted_rules, statistics_file)
    print(f'Number of extracted reaction rules: {len(sorted_rules)}')

    ray.shutdown()


@ray.remote
def process_reaction_batch(
    batch: List[Tuple[int, ReactionContainer]], config: RuleExtractionConfig
) -> list[tuple[int, list[ReactionContainer]]]:
    """
    Ray remote task: extracts reaction rules for a batch of indexed reactions.

    :param batch: list of (index, reaction) pairs; the index tracks the reaction's
        position in the source dataset
    :param config: settings and parameters for the rule extraction process
    :return: list of (index, extracted_rules) pairs; reactions whose extraction
        fails are silently skipped (best-effort batch processing)
    """
    processed_batch = []
    for index, reaction in batch:
        try:
            extracted_rules = extract_rules(config, reaction)
            processed_batch.append((index, extracted_rules))
        except Exception:  # fix: bare except also swallowed KeyboardInterrupt/SystemExit
            continue
    return processed_batch


def process_completed_batches(
    futures: dict,
    rules_statistics: Dict[ReactionContainer, List[int]],
    pbar: tqdm,
    batch_size: int,
) -> None:
    """
    Waits for one in-flight Ray batch, folds its rules into the statistics and
    advances the progress bar.

    :param futures: mapping of in-flight Ray futures; the completed one is removed
    :param rules_statistics: per-rule list of reaction indices where the rule occurred
    :param pbar: tqdm progress bar advanced by batch_size
    :param batch_size: number of reactions covered by the completed batch
    :return: None
    """
    done, _ = ray.wait(list(futures.keys()), num_returns=1)
    completed_batch = ray.get(done[0])

    for index, extracted_rules in completed_batch:
        for rule in extracted_rules:
            prev_stats_len = len(rules_statistics)
            rules_statistics[rule].append(index)
            # first time this rule is seen: remember which reaction produced it
            if len(rules_statistics) != prev_stats_len:
                rule.meta["first_reaction_index"] = index

    del futures[done[0]]
    pbar.update(batch_size)
+ """ + if config.multicenter_rules: + # Extract a single rule encompassing all reaction centers + return [create_rule(config, reaction)] + else: + # Extract separate rules for each distinct reaction center + distinct_rules = set() + for center_reaction in islice(reaction.enumerate_centers(), 15): + single_rule = create_rule(config, center_reaction) + distinct_rules.add(single_rule) + return list(distinct_rules) + + +def create_rule( + config: RuleExtractionConfig, reaction: ReactionContainer +) -> ReactionContainer: + """ + Creates a reaction rule from a given reaction based on the specified configuration. + + :param config: An instance of ExtractRuleConfig, containing various settings that determine how + the rule is created, such as environmental atom count, inclusion of functional groups, + rings, leaving and incoming groups, and other parameters. + :param reaction: The reaction object (ReactionContainer) from which to create the rule. This object + represents a chemical reaction with specified reactants, products, and possibly reagents. + :return: A ReactionContainer object representing the extracted reaction rule. This rule includes + various elements of the reaction as specified by the configuration, such as reaction centers, + environmental atoms, functional groups, and others. + + The function processes the reaction to create a rule that matches the configuration settings. It handles + the inclusion of environmental atoms, functional groups, ring structures, and leaving and incoming groups. + It also constructs substructures for reactants, products, and reagents, and cleans molecule representations + if required. Optionally, it validates the rule using a reactor. 
+ """ + cgr = ~reaction + center_atoms = set(cgr.center_atoms) + + # Add atoms of reaction environment based on config settings + center_atoms = add_environment_atoms( + cgr, center_atoms, config.environment_atom_count + ) + + # Include functional groups in the rule if specified in config + if config.include_func_groups: + rule_atoms = add_functional_groups( + reaction, center_atoms, config.func_groups_list + ) + else: + rule_atoms = center_atoms.copy() + + # Include ring structures in the rule if specified in config + if config.include_rings: + rule_atoms = add_ring_structures( + cgr, + rule_atoms, + ) + + # Add leaving and incoming groups to the rule based on config settings + rule_atoms, meta_debug = add_leaving_incoming_groups( + reaction, rule_atoms, config.keep_leaving_groups, config.keep_incoming_groups + ) + + # Create substructures for reactants, products, and reagents + ( + reactant_substructures, + product_substructures, + reagents, + ) = create_substructures_and_reagents( + reaction, rule_atoms, config.as_query_container, config.keep_reagents + ) + + # Clean atom marks in the molecules if they are being converted to query containers + if config.as_query_container: + reactant_substructures = clean_molecules( + reactant_substructures, + reaction.reactants, + center_atoms, + config.atom_info_retention, + ) + product_substructures = clean_molecules( + product_substructures, + reaction.products, + center_atoms, + config.atom_info_retention, + ) + + # Assemble the final rule including metadata if specified + rule = assemble_final_rule( + reactant_substructures, + product_substructures, + reagents, + meta_debug, + config.keep_metadata, + reaction, + ) + + if config.reverse_rule: + rule = reverse_reaction(rule) + reaction = reverse_reaction(reaction) + + # Validate the rule using a reactor if validation is enabled in config + if config.reactor_validation: + if validate_rule(rule, reaction): + rule.meta["reactor_validation"] = "passed" + else: + 
rule.meta["reactor_validation"] = "failed" + + return rule + + +def add_environment_atoms(cgr, center_atoms, environment_atom_count): + """ + Adds environment atoms to the set of center atoms based on the specified depth. + + :param cgr: A complete graph representation of a reaction (ReactionContainer object). + :param center_atoms: A set of atom identifiers representing the center atoms of the reaction. + :param environment_atom_count: An integer specifying the depth of the environment around + the reaction center to be included. If it's 0, only the + reaction center is included. If it's 1, the first layer of + surrounding atoms is included, and so on. + :return: A set of atom identifiers including the center atoms and their environment atoms + up to the specified depth. If environment_atom_count is 0, the original set of + center atoms is returned unchanged. + """ + if environment_atom_count: + env_cgr = cgr.augmented_substructure(center_atoms, deep=environment_atom_count) + # Combine the original center atoms with the new environment atoms + return center_atoms | set(env_cgr) + + # If no environment is to be included, return the original center atoms + return center_atoms + + +def add_functional_groups(reaction, center_atoms, func_groups_list): + """ + Augments the set of rule atoms with functional groups if specified. + + :param reaction: The reaction object (ReactionContainer) from which molecules are extracted. + :param center_atoms: A set of atom identifiers representing the center atoms of the reaction. + :param func_groups_list: A list of functional group objects (MoleculeContainer or QueryContainer) + to be considered when including functional groups. These objects define + the structure of the functional groups to be included. + :return: A set of atom identifiers representing the rule atoms, including atoms from the + specified functional groups if include_func_groups is True. 
def add_functional_groups(reaction, center_atoms, func_groups_list):
    """
    Extends the rule atoms with functional groups overlapping the reaction center.

    :param reaction: the reaction whose molecules are scanned for functional groups
    :param center_atoms: atom numbers of the reaction center
    :param func_groups_list: functional group templates (MoleculeContainer or
        QueryContainer) to search for
    :return: set of rule atom numbers, extended with every matched functional
        group that shares at least one atom with the reaction center
    """
    rule_atoms = center_atoms.copy()
    for mol in reaction.molecules():
        for group in func_groups_list:
            for found_map in group.get_mapping(mol):
                # project the template onto the molecule numbering
                group.remap(found_map)
                group_atoms = set(group.atoms_numbers)
                if group_atoms & center_atoms:
                    rule_atoms |= group_atoms
                # undo the remap so the template is clean for the next match
                group.remap({v: k for k, v in found_map.items()})
    return rule_atoms


def add_ring_structures(cgr, rule_atoms):
    """
    Completes every ring that touches the current rule atoms.

    :param cgr: the reaction CGR providing the smallest set of smallest rings
    :param rule_atoms: current set of rule atom numbers
    :return: rule atoms extended with all atoms of every intersecting ring
    """
    for ring in cgr.sssr:
        ring_atoms = set(ring)
        if ring_atoms & rule_atoms:
            rule_atoms |= ring_atoms
    return rule_atoms


def add_leaving_incoming_groups(
    reaction, rule_atoms, keep_leaving_groups, keep_incoming_groups
):
    """
    Optionally adds leaving and incoming groups to the rule atoms.

    :param reaction: the reaction providing reactant and product atoms
    :param rule_atoms: current set of rule atom numbers
    :param keep_leaving_groups: include atoms present in reactants but not products
    :param keep_incoming_groups: include atoms present in products but not reactants
    :return: tuple of (updated rule atoms, metadata dict with "leaving" and
        "incoming" sets of the atoms newly added to the rule)
    """
    meta_debug = {"leaving": set(), "incoming": set()}

    reactant_atoms = {atom for reactant in reaction.reactants for atom in reactant}
    product_atoms = {atom for product in reaction.products for atom in product}

    if keep_leaving_groups:
        # leaving group = reactant atoms that disappear from the products
        leaving_atoms = reactant_atoms - product_atoms
        meta_debug["leaving"] |= leaving_atoms - rule_atoms
        rule_atoms |= leaving_atoms

    if keep_incoming_groups:
        # incoming group = product atoms absent from the reactants
        incoming_atoms = product_atoms - reactant_atoms
        meta_debug["incoming"] |= incoming_atoms - rule_atoms
        rule_atoms |= incoming_atoms

    return rule_atoms, meta_debug


def clean_molecules(
    rule_molecules: Iterable[QueryContainer],
    reaction_molecules: Iterable[MoleculeContainer],
    reaction_center_atoms: Set[int],
    atom_retention_details: Dict[str, Dict[str, bool]],
) -> List[QueryContainer]:
    """
    Strips configured atom marks from the rule molecules.

    Each rule molecule is re-derived as a query substructure of the reaction
    molecule that contains it; then atom attributes (neighbors, hybridization,
    implicit hydrogens, ring sizes) are cleared separately for environment atoms
    and reaction-center atoms according to the retention settings.

    :param rule_molecules: rule fragments to clean
    :param reaction_molecules: source reaction molecules the fragments came from
    :param reaction_center_atoms: atom numbers belonging to the reaction center
    :param atom_retention_details: {"reaction_center": {...}, "environment": {...}}
        with boolean flags per attribute; True keeps the mark, False removes it
    :return: list of cleaned query molecules
    """
    cleaned = []
    for rule_mol in rule_molecules:
        rule_atom_set = set(rule_mol.atoms_numbers)
        # locate the parent reaction molecule that contains this rule fragment
        for parent_mol in reaction_molecules:
            if not rule_atom_set <= set(parent_mol.atoms_numbers):
                continue
            parent_query = parent_mol.substructure(parent_mol, as_query=True)
            rule_query = parent_query.substructure(rule_mol)

            env_marks = atom_retention_details["environment"]
            if not all(env_marks.values()):  # all True means keep every mark
                for atom_number in rule_atom_set - reaction_center_atoms:
                    rule_query = clean_atom(rule_query, env_marks, atom_number)

            center_marks = atom_retention_details["reaction_center"]
            if not all(center_marks.values()):  # all True means keep every mark
                for atom_number in rule_atom_set & reaction_center_atoms:
                    rule_query = clean_atom(rule_query, center_marks, atom_number)

            cleaned.append(rule_query)
            break
    return cleaned


def clean_atom(
    query_molecule: QueryContainer,
    attributes_to_keep: Dict[str, bool],
    atom_number: int,
) -> QueryContainer:
    """
    Clears the unwanted marks of a single atom in a query molecule.

    :param query_molecule: the query molecule holding the atom
    :param attributes_to_keep: per-attribute booleans ("neighbors",
        "hybridization", "implicit_hydrogens", "ring_sizes"); True keeps the
        attribute, False sets it to None
    :param atom_number: number of the atom to clean
    :return: the same query molecule, with the atom's marks updated in place
    """
    target_atom = query_molecule.atom(atom_number)
    for attribute in ("neighbors", "hybridization", "implicit_hydrogens", "ring_sizes"):
        if not attributes_to_keep[attribute]:
            setattr(target_atom, attribute, None)
    return query_molecule
def create_substructures_and_reagents(
    reaction, rule_atoms, as_query_container, keep_reagents
):
    """
    Cuts rule substructures out of the reactants and products and collects reagents.

    :param reaction: the reaction to cut substructures from
    :param rule_atoms: atom numbers that define the rule; only molecules sharing
        at least one of these atoms contribute a substructure
    :param as_query_container: if True, reagents are converted to query containers
    :param keep_reagents: if False, reagents are dropped (empty list returned)
    :return: tuple (reactant substructures, product substructures, reagents)
    """
    def _rule_fragments(molecules):
        # keep only molecules that share atoms with the rule, cut to those atoms
        fragments = []
        for mol in molecules:
            shared_atoms = rule_atoms.intersection(mol.atoms_numbers)
            if shared_atoms:
                fragments.append(mol.substructure(shared_atoms))
        return fragments

    reactant_substructures = _rule_fragments(reaction.reactants)
    product_substructures = _rule_fragments(reaction.products)

    if not keep_reagents:
        reagents = []
    elif as_query_container:
        reagents = [
            reagent.substructure(reagent, as_query=True)
            for reagent in reaction.reagents
        ]
    else:
        reagents = reaction.reagents

    return reactant_substructures, product_substructures, reagents


def assemble_final_rule(
    reactant_substructures,
    product_substructures,
    reagents,
    meta_debug,
    keep_metadata,
    reaction,
):
    """
    Assembles the final reaction rule container from its parts.

    :param reactant_substructures: rule substructures cut from the reactants
    :param product_substructures: rule substructures cut from the products
    :param reagents: reagents to attach to the rule (possibly empty)
    :param meta_debug: metadata about added leaving/incoming groups
    :param keep_metadata: if True, meta_debug plus the reaction's own metadata and
        name are carried over to the rule
    :param reaction: the source reaction providing metadata and name
    :return: a ReactionContainer representing the assembled rule
    """
    if keep_metadata:
        rule_metadata = meta_debug
        rule_metadata.update(reaction.meta)
    else:
        rule_metadata = {}

    rule = ReactionContainer(
        reactant_substructures, product_substructures, reagents, rule_metadata
    )

    if keep_metadata:
        rule.name = reaction.name

    rule.flush_cache()
    return rule
+ """ + # Create a reactor with the given rule + reactor = Reactor(rule) + try: + for result_reaction in reactor(reaction.reactants): + result_products = [] + for result_product in result_reaction.products: + tmp = result_product.copy() + try: + tmp.kekule() + if tmp.check_valence(): + continue + except InvalidAromaticRing: + continue + result_products.append(result_product) + if set(reaction.products) == set(result_products) and len( + reaction.products + ) == len(result_products): + return True + except (KeyError, IndexError): + # KeyError - iteration over reactor is finished and products are different from the original reaction + # IndexError - mistake in __contract_ions, possibly problems with charges in rule? + return False + + +def sort_rules( + rules_stats: Dict[ReactionContainer, List[int]], + min_popularity: int = 3, + single_reactant_only: bool = True, +) -> List[Tuple[ReactionContainer, List[int]]]: + """ + Sorts reaction rules based on their popularity and validation status. + + This function sorts the given rules according to their popularity (i.e., the number of times they have been + applied) and filters out rules that haven't passed reactor validation or are less popular than the specified + minimum popularity threshold. + + :param rules_stats: A dictionary where each key is a reaction rule and the value is a list of integers. + Each integer represents an index where the rule was applied. + :type rules_stats: Dict[ReactionContainer, List[int]] + + :param min_popularity: The minimum number of times a rule must be applied to be considered. Default is 3. + :type min_popularity: int + + :param single_reactant_only: Whether to keep only reaction rules with a single molecule on the right side + of reaction arrow. Default is True. + + :return: A list of tuples, where each tuple contains a reaction rule and a list of indices representing + the rule's applications. The list is sorted in descending order of the rule's popularity. 
+ :rtype: List[Tuple[ReactionContainer, List[int]]] + """ + return sorted( + ( + (rule, indices) + for rule, indices in rules_stats.items() + if len(indices) >= min_popularity + and rule.meta["reactor_validation"] == "passed" + and (not single_reactant_only or len(rule.reactants) == 1) + ), + key=lambda x: -len(x[1]), + ) diff --git a/SynTool/chem/reaction_rules/manual/__init__.py b/SynTool/chem/reaction_rules/manual/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..1ac0aa571b5a5fc26f40aaf6f3150c02a5a74b7e --- /dev/null +++ b/SynTool/chem/reaction_rules/manual/__init__.py @@ -0,0 +1,6 @@ +from .decompositions import rules as d_rules +from .transformations import rules as t_rules + +hardcoded_rules = t_rules + d_rules + +__all__ = ["hardcoded_rules"] diff --git a/SynTool/chem/reaction_rules/manual/decompositions.py b/SynTool/chem/reaction_rules/manual/decompositions.py new file mode 100755 index 0000000000000000000000000000000000000000..8192e62a579d4e0af1c092bb9ba615fd39d4a403 --- /dev/null +++ b/SynTool/chem/reaction_rules/manual/decompositions.py @@ -0,0 +1,415 @@ +""" +Module containing hardcoded decomposition reaction rules +""" + +from CGRtools import QueryContainer, ReactionContainer +from CGRtools.periodictable import ListElement + +rules = [] + + +def prepare(): + """ + Creates and returns three query containers and appends a reaction container to the "rules" list + """ + q_ = QueryContainer() + p1_ = QueryContainer() + p2_ = QueryContainer() + rules.append(ReactionContainer((q_,), (p1_, p2_))) + return q_, p1_, p2_ + + +# R-amide/ester formation +# [C](-[N,O;D23;Zs])(-[C])=[O]>>[A].[C]-[C](-[O])=[O] +q, p1, p2 = prepare() +q.add_atom('C') +q.add_atom('C') +q.add_atom('O') +q.add_atom(ListElement(['N', 'O']), hybridization=1, neighbors=(2, 3)) +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 2) +q.add_bond(2, 4, 1) + +p1.add_atom('C') +p1.add_atom('C') +p1.add_atom('O') +p1.add_atom('O', _map=5) +p1.add_bond(1, 2, 1) +p1.add_bond(2, 3, 
2) +p1.add_bond(2, 5, 1) + +p2.add_atom('A', _map=4) + +# acyl group addition with aromatic carbon's case (Friedel-Crafts) +# [C;Za]-[C](-[C])=[O]>>[C].[C]-[C](-[Cl])=[O] +q, p1, p2 = prepare() +q.add_atom('C') +q.add_atom('C') +q.add_atom('O') +q.add_atom('C', hybridization=4) +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 2) +q.add_bond(2, 4, 1) + +p1.add_atom('C') +p1.add_atom('C') +p1.add_atom('O') +p1.add_atom('Cl', _map=5) +p1.add_bond(1, 2, 1) +p1.add_bond(2, 3, 2) +p1.add_bond(2, 5, 1) + +p2.add_atom('C', _map=4) + +# Williamson reaction +# [C;Za]-[O]-[C;Zs;W0]>>[C]-[Br].[C]-[O] +q, p1, p2 = prepare() +q.add_atom('C', hybridization=4) +q.add_atom('O') +q.add_atom('C', hybridization=1, heteroatoms=1) +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 1) + +p1.add_atom('C') +p1.add_atom('O') +p1.add_bond(1, 2, 1) + +p2.add_atom('C', _map=3) +p2.add_atom('Br') +p2.add_bond(3, 4, 1) + +# Buchwald-Hartwig amination +# [N;D23;Zs;W0]-[C;Za]>>[C]-[Br].[N] +q, p1, p2 = prepare() +q.add_atom('N', heteroatoms=0, hybridization=1, neighbors=(2, 3)) +q.add_atom('C', hybridization=4) +q.add_bond(1, 2, 1) + +p1.add_atom('C', _map=2) +p1.add_atom('Br') +p1.add_bond(2, 3, 1) + +p2.add_atom('N') + +# imidazole imine atom's alkylation +# [C;r5](:[N;r5]-[C;Zs;W1]):[N;D2;r5]>>[C]-[Br].[N]:[C]:[N] +q, p1, p2 = prepare() +q.add_atom('N', rings_sizes=5) +q.add_atom('C', rings_sizes=5) +q.add_atom('N', rings_sizes=5, neighbors=2) +q.add_atom('C', hybridization=1, heteroatoms=(1, 2)) +q.add_bond(1, 2, 4) +q.add_bond(2, 3, 4) +q.add_bond(1, 4, 1) + +p1.add_atom('N') +p1.add_atom('C') +p1.add_atom('N') +p1.add_bond(1, 2, 4) +p1.add_bond(2, 3, 4) + +p2.add_atom('C', _map=4) +p2.add_atom('Br') +p2.add_bond(4, 5, 1) + +# Knoevenagel condensation (nitryl and carboxyl case) +# [C]=[C](-[C]#[N])-[C](-[O])=[O]>>[C]=[O].[C](-[C]#[N])-[C](-[O])=[O] +q, p1, p2 = prepare() +q.add_atom('C') +q.add_atom('C') +q.add_atom('C') +q.add_atom('N') +q.add_atom('C') +q.add_atom('O') +q.add_atom('O') +q.add_bond(1, 2, 2) 
+q.add_bond(2, 3, 1) +q.add_bond(3, 4, 3) +q.add_bond(2, 5, 1) +q.add_bond(5, 6, 2) +q.add_bond(5, 7, 1) + +p1.add_atom('C', _map=2) +p1.add_atom('C') +p1.add_atom('N') +p1.add_atom('C') +p1.add_atom('O') +p1.add_atom('O') +p1.add_bond(2, 3, 1) +p1.add_bond(3, 4, 3) +p1.add_bond(2, 5, 1) +p1.add_bond(5, 6, 2) +p1.add_bond(5, 7, 1) + +p2.add_atom('C', _map=1) +p2.add_atom('O', _map=8) +p2.add_bond(1, 8, 2) + +# Knoevenagel condensation (double nitryl case) +# [C]=[C](-[C]#[N])-[C]#[N]>>[C]=[O].[C](-[C]#[N])-[C]#[N] +q, p1, p2 = prepare() +q.add_atom('C') +q.add_atom('C') +q.add_atom('C') +q.add_atom('N') +q.add_atom('C') +q.add_atom('N') +q.add_bond(1, 2, 2) +q.add_bond(2, 3, 1) +q.add_bond(3, 4, 3) +q.add_bond(2, 5, 1) +q.add_bond(5, 6, 3) + +p1.add_atom('C', _map=2) +p1.add_atom('C') +p1.add_atom('N') +p1.add_atom('C') +p1.add_atom('N') +p1.add_bond(2, 3, 1) +p1.add_bond(3, 4, 3) +p1.add_bond(2, 5, 1) +p1.add_bond(5, 6, 3) + +p2.add_atom('C', _map=1) +p2.add_atom('O', _map=8) +p2.add_bond(1, 8, 2) + +# Knoevenagel condensation (double carboxyl case) +# [C]=[C](-[C](-[O])=[O])-[C](-[O])=[O]>>[C]=[O].[C](-[C](-[O])=[O])-[C](-[O])=[O] +q, p1, p2 = prepare() +q.add_atom('C') +q.add_atom('C') +q.add_atom('C') +q.add_atom('O') +q.add_atom('O') +q.add_atom('C') +q.add_atom('O') +q.add_atom('O') +q.add_bond(1, 2, 2) +q.add_bond(2, 3, 1) +q.add_bond(3, 4, 2) +q.add_bond(3, 5, 1) +q.add_bond(2, 6, 1) +q.add_bond(6, 7, 2) +q.add_bond(6, 8, 1) + +p1.add_atom('C', _map=2) +p1.add_atom('C') +p1.add_atom('O') +p1.add_atom('O') +p1.add_atom('C') +p1.add_atom('O') +p1.add_atom('O') +p1.add_bond(2, 3, 1) +p1.add_bond(3, 4, 2) +p1.add_bond(3, 5, 1) +p1.add_bond(2, 6, 1) +p1.add_bond(6, 7, 2) +p1.add_bond(6, 8, 1) + +p2.add_atom('C', _map=1) +p2.add_atom('O', _map=9) +p2.add_bond(1, 9, 2) + +# heterocyclization with guanidine +# [c]((-[N;W0;Zs])@[n]@[c](-[N;D1])@[c;W0])@[n]@[c]-[O; D1]>>[C](-[N])(=[N])-[N].[C](#[N])-[C]-[C](-[O])=[O] +q, p1, p2 = prepare() +q.add_atom('C') 
+q.add_atom('N', heteroatoms=0, hybridization=1) +q.add_atom('N') +q.add_atom('C') +q.add_atom('N', neighbors=1) +q.add_atom('C', heteroatoms=0) +q.add_atom('N') +q.add_atom('C') +q.add_atom('O', neighbors=1) +q.add_bond(1, 2, 1) +q.add_bond(1, 3, 4) +q.add_bond(3, 4, 4) +q.add_bond(4, 5, 1) +q.add_bond(4, 6, 4) +q.add_bond(1, 7, 4) +q.add_bond(7, 8, 4) +q.add_bond(8, 9, 1) + +p1.add_atom('C') +p1.add_atom('N') +p1.add_atom('N') +p1.add_atom('N', _map=7) +p1.add_bond(1, 2, 1) +p1.add_bond(1, 3, 2) +p1.add_bond(1, 7, 1) + +p2.add_atom('C', _map=4) +p2.add_atom('N') +p2.add_atom('C') +p2.add_atom('C', _map=8) +p2.add_atom('O', _map=9) +p2.add_atom('O') +p2.add_bond(4, 5, 3) +p2.add_bond(4, 6, 1) +p2.add_bond(6, 8, 1) +p2.add_bond(8, 9, 2) +p2.add_bond(8, 10, 1) + +# alkylation of amine +# [C]-[N]-[C]>>[C]-[N].[C]-[Br] +q, p1, p2 = prepare() +q.add_atom('C') +q.add_atom('N') +q.add_atom('C') +q.add_atom('C') +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 1) +q.add_bond(2, 4, 1) + +p1.add_atom('C') +p1.add_atom('N') +p1.add_atom('C') +p1.add_bond(1, 2, 1) +p1.add_bond(2, 3, 1) + +p2.add_atom('C', _map=4) +p2.add_atom('Cl') +p2.add_bond(4, 5, 1) + +# Synthesis of guanidines +# +q, p1, p2 = prepare() +q.add_atom('N') +q.add_atom('C') +q.add_atom('N', hybridization=1) +q.add_atom('N', hybridization=1) +q.add_bond(1, 2, 2) +q.add_bond(2, 3, 1) +q.add_bond(2, 4, 1) + +p1.add_atom('N') +p1.add_atom('C') +p1.add_atom('N') +p1.add_bond(1, 2, 3) +p1.add_bond(2, 3, 1) + +p2.add_atom('N', _map=4) + +# Grignard reaction with nitrile +# +q, p1, p2 = prepare() +q.add_atom('C') +q.add_atom('C') +q.add_atom('O') +q.add_atom('C') +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 2) +q.add_bond(2, 4, 1) + +p1.add_atom('C') +p1.add_atom('C') +p1.add_atom('N') +p1.add_bond(1, 2, 1) +p1.add_bond(2, 3, 3) + +p2.add_atom('C', _map=4) +p2.add_atom('Br') +p2.add_bond(4, 5, 1) + +# Alkylation of alpha-carbon atom of nitrile +# +q, p1, p2 = prepare() +q.add_atom('N') +q.add_atom('C') +q.add_atom('C', neighbors=(3, 
4)) +q.add_atom('C', hybridization=1) +q.add_bond(1, 2, 3) +q.add_bond(2, 3, 1) +q.add_bond(3, 4, 1) + +p1.add_atom('N') +p1.add_atom('C') +p1.add_atom('C') +p1.add_bond(1, 2, 3) +p1.add_bond(2, 3, 1) + +p2.add_atom('C', _map=4) +p2.add_atom('Cl') +p2.add_bond(4, 5, 1) + +# Gomberg-Bachmann reaction +# +q, p1, p2 = prepare() +q.add_atom('C', hybridization=4, heteroatoms=0) +q.add_atom('C', hybridization=4, heteroatoms=0) +q.add_bond(1, 2, 1) + +p1.add_atom('C') +p1.add_atom('N', _map=3) +p1.add_bond(1, 3, 1) + +p2.add_atom('C', _map=2) + +# Cyclocondensation +# +q, p1, p2 = prepare() +q.add_atom('N', neighbors=2) +q.add_atom('C') +q.add_atom('C') +q.add_atom('C') +q.add_atom('N') +q.add_atom('C') +q.add_atom('C') +q.add_atom('O', neighbors=1) +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 1) +q.add_bond(3, 4, 1) +q.add_bond(4, 5, 2) +q.add_bond(5, 6, 1) +q.add_bond(6, 7, 1) +q.add_bond(7, 8, 2) +q.add_bond(1, 7, 1) + +p1.add_atom('N') +p1.add_atom('C') +p1.add_atom('C') +p1.add_atom('C') +p1.add_atom('O', _map=9) +p1.add_bond(1, 2, 1) +p1.add_bond(2, 3, 1) +p1.add_bond(3, 4, 1) +p1.add_bond(4, 9, 2) + +p2.add_atom('N', _map=5) +p2.add_atom('C') +p2.add_atom('C') +p2.add_atom('O') +p2.add_atom('O', _map=10) +p2.add_bond(5, 6, 1) +p2.add_bond(6, 7, 1) +p2.add_bond(7, 8, 2) +p2.add_bond(7, 10, 1) + +# heterocyclization dicarboxylic acids +# +q, p1, p2 = prepare() +q.add_atom('C', rings_sizes=(5, 6)) +q.add_atom('O') +q.add_atom(ListElement(['O', 'N'])) +q.add_atom('C', rings_sizes=(5, 6)) +q.add_atom('O') +q.add_bond(1, 2, 2) +q.add_bond(1, 3, 1) +q.add_bond(3, 4, 1) +q.add_bond(4, 5, 2) + +p1.add_atom('C') +p1.add_atom('O') +p1.add_atom('O', _map=6) +p1.add_bond(1, 2, 2) +p1.add_bond(1, 6, 1) + +p2.add_atom('C', _map=4) +p2.add_atom('O') +p2.add_atom('O', _map=7) +p2.add_bond(4, 5, 2) +p2.add_bond(4, 7, 1) + +__all__ = ['rules'] diff --git a/SynTool/chem/reaction_rules/manual/transformations.py b/SynTool/chem/reaction_rules/manual/transformations.py new file mode 100755 
index 0000000000000000000000000000000000000000..6a8890a43aed1a08e4dc7c11ea0ef13129ab09da --- /dev/null +++ b/SynTool/chem/reaction_rules/manual/transformations.py @@ -0,0 +1,535 @@ +""" +Module containing hardcoded transformation reaction rules +""" + +from CGRtools import QueryContainer, ReactionContainer +from CGRtools.periodictable import ListElement + +rules = [] + + +def prepare(): + """ + Creates and returns three query containers and appends a reaction container to the "rules" list + """ + q_ = QueryContainer() + p_ = QueryContainer() + rules.append(ReactionContainer((q_,), (p_,))) + return q_, p_ + + +# aryl nitro reduction +# [C;Za;W1]-[N;D1]>>[O-]-[N+](-[C])=[O] +q, p = prepare() +q.add_atom('N', neighbors=1) +q.add_atom('C', hybridization=4, heteroatoms=1) +q.add_bond(1, 2, 1) + +p.add_atom('N', charge=1) +p.add_atom('C') +p.add_atom('O', charge=-1) +p.add_atom('O') +p.add_bond(1, 2, 1) +p.add_bond(1, 3, 1) +p.add_bond(1, 4, 2) + +# aryl nitration +# [O-]-[N+](=[O])-[C;Za;W12]>>[C] +q, p = prepare() +q.add_atom('N', charge=1) +q.add_atom('C', hybridization=4, heteroatoms=(1, 2)) +q.add_atom('O', charge=-1) +q.add_atom('O') +q.add_bond(1, 2, 1) +q.add_bond(1, 3, 1) +q.add_bond(1, 4, 2) + +p.add_atom('C', _map=2) + +# Beckmann rearrangement (oxime -> amide) +# [C]-[N;D2]-[C]=[O]>>[O]-[N]=[C]-[C] +q, p = prepare() +q.add_atom('C') +q.add_atom('N', neighbors=2) +q.add_atom('O') +q.add_atom('C') +q.add_bond(1, 2, 1) +q.add_bond(1, 3, 2) +q.add_bond(2, 4, 1) + +p.add_atom('C') +p.add_atom('N') +p.add_atom('O') +p.add_atom('C') +p.add_bond(1, 2, 2) +p.add_bond(2, 3, 1) +p.add_bond(1, 4, 1) + +# aldehydes or ketones into oxime/imine reaction +# [C;Zd;W1]=[N]>>[C]=[O] +q, p = prepare() +q.add_atom('C', hybridization=2, heteroatoms=1) +q.add_atom('N') +q.add_bond(1, 2, 2) + +p.add_atom('C') +p.add_atom('O', _map=3) +p.add_bond(1, 3, 2) + +# addition of halogen atom into phenol ring (orto) +# [C](-[Cl,F,Br,I;D1]):[C]-[O,N;Zs]>>[C](-[A]):[C] +q, p = prepare() 
+q.add_atom(ListElement(['O', 'N']), hybridization=1) +q.add_atom('C') +q.add_atom('C') +q.add_atom(ListElement(['Cl', 'F', 'Br', 'I']), neighbors=1) +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 4) +q.add_bond(3, 4, 1) + +p.add_atom('A') +p.add_atom('C') +p.add_atom('C') +p.add_bond(1, 2, 1) +p.add_bond(2, 3, 4) + +# addition of halogen atom into phenol ring (para) +# [C](:[C]:[C]:[C]-[O,N;Zs])-[Cl,F,Br,I;D1]>>[A]-[C]:[C]:[C]:[C] +q, p = prepare() +q.add_atom(ListElement(['O', 'N']), hybridization=1) +q.add_atom('C') +q.add_atom('C') +q.add_atom('C') +q.add_atom('C') +q.add_atom(ListElement(['Cl', 'F', 'Br', 'I']), neighbors=1) +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 4) +q.add_bond(3, 4, 4) +q.add_bond(4, 5, 4) +q.add_bond(5, 6, 1) + +p.add_atom('A') +p.add_atom('C') +p.add_atom('C') +p.add_atom('C') +p.add_atom('C') +p.add_bond(1, 2, 1) +p.add_bond(2, 3, 4) +p.add_bond(3, 4, 4) +p.add_bond(4, 5, 4) + +# hard reduction of Ar-ketones +# [C;Za]-[C;D2;Zs;W0]>>[C]-[C]=[O] +q, p = prepare() +q.add_atom('C', hybridization=4) +q.add_atom('C', hybridization=1, neighbors=2, heteroatoms=0) +q.add_bond(1, 2, 1) + +p.add_atom('C') +p.add_atom('C') +p.add_atom('O') +p.add_bond(1, 2, 1) +p.add_bond(2, 3, 2) + +# reduction of alpha-hydroxy pyridine +# [C;W1]:[N;H0;r6]>>[C](:[N])-[O] +q, p = prepare() +q.add_atom('C', heteroatoms=1) +q.add_atom('N', rings_sizes=6, hydrogens=0) +q.add_bond(1, 2, 4) + +p.add_atom('C') +p.add_atom('N') +p.add_atom('O') +p.add_bond(1, 2, 4) +p.add_bond(1, 3, 1) + +# Reduction of alkene +# [C]-[C;D23;Zs;W0]-[C;D123;Zs;W0]>>[C](-[C])=[C] +q, p = prepare() +q.add_atom('C') +q.add_atom('C', heteroatoms=0, neighbors=(2, 3), hybridization=1) +q.add_atom('C', heteroatoms=0, neighbors=(1, 2, 3), hybridization=1) +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 1) + +p.add_atom('C') +p.add_atom('C') +p.add_atom('C') +p.add_bond(1, 2, 1) +p.add_bond(2, 3, 2) + +# Kolbe-Schmitt reaction +# [C](:[C]-[O;D1])-[C](=[O])-[O;D1]>>[C](-[O]):[C] +q, p = prepare() +q.add_atom('O', 
neighbors=1) +q.add_atom('C') +q.add_atom('C') +q.add_atom('C') +q.add_atom('O', neighbors=1) +q.add_atom('O') +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 4) +q.add_bond(3, 4, 1) +q.add_bond(4, 5, 1) +q.add_bond(4, 6, 2) + +p.add_atom('O') +p.add_atom('C') +p.add_atom('C') +p.add_bond(1, 2, 1) +p.add_bond(2, 3, 4) + +# reduction of carboxylic acid +# [O;D1]-[C;D2]-[C]>>[C]-[C](-[O])=[O] +q, p = prepare() +q.add_atom('C') +q.add_atom('C', neighbors=2) +q.add_atom('O', neighbors=1) +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 1) + +p.add_atom('C') +p.add_atom('C') +p.add_atom('O') +p.add_atom('O') +p.add_bond(1, 2, 1) +p.add_bond(2, 3, 1) +p.add_bond(2, 4, 2) + +# halogenation of alcohols +# [C;Zs]-[Cl,Br;D1]>>[C]-[O] +q, p = prepare() +q.add_atom('C', hybridization=1, heteroatoms=1) +q.add_atom(ListElement(['Cl', 'Br']), neighbors=1) +q.add_bond(1, 2, 1) + +p.add_atom('C') +p.add_atom('O', _map=3) +p.add_bond(1, 3, 1) + +# Kolbe nitrilation +# [N]#[C]-[C;Zs;W0]>>[Br]-[C] +q, p = prepare() +q.add_atom('C', heteroatoms=0, hybridization=1) +q.add_atom('C') +q.add_atom('N') +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 3) + +p.add_atom('C') +p.add_atom('Br', _map=4) +p.add_bond(1, 4, 1) + +# Nitrile hydrolysis +# [O;D1]-[C]=[O]>>[N]#[C] +q, p = prepare() +q.add_atom('C') +q.add_atom('O', neighbors=1) +q.add_atom('O') +q.add_bond(1, 2, 1) +q.add_bond(1, 3, 2) + +p.add_atom('C') +p.add_atom('N', _map=4) +p.add_bond(1, 4, 3) + +# sulfamidation +# [c]-[S](=[O])(=[O])-[N]>>[c] +q, p = prepare() +q.add_atom('C', hybridization=4) +q.add_atom('S') +q.add_atom('O') +q.add_atom('O') +q.add_atom('N', neighbors=1) +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 2) +q.add_bond(2, 4, 2) +q.add_bond(2, 5, 1) + +p.add_atom('C') + +# Ring expansion rearrangement +# +q, p = prepare() +q.add_atom('C') +q.add_atom('N') +q.add_atom('C', rings_sizes=6) +q.add_atom('C') +q.add_atom('O') +q.add_atom('C') +q.add_atom('C') +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 1) +q.add_bond(3, 4, 1) +q.add_bond(4, 5, 2) +q.add_bond(3, 
6, 1) +q.add_bond(4, 7, 1) + +p.add_atom('C') +p.add_atom('N') +p.add_atom('C') +p.add_atom('C') +p.add_atom('O') +p.add_atom('C') +p.add_atom('C') +p.add_bond(1, 2, 1) +p.add_bond(2, 3, 2) +p.add_bond(3, 4, 1) +p.add_bond(4, 5, 1) +p.add_bond(4, 6, 1) +p.add_bond(4, 7, 1) + +# hydrolysis of bromide alkyl +# +q, p = prepare() +q.add_atom('C', hybridization=1) +q.add_atom('O', neighbors=1) +q.add_bond(1, 2, 1) + +p.add_atom('C') +p.add_atom('Br') +p.add_bond(1, 2, 1) + +# Condensation of ketones/aldehydes and amines into imines +# +q, p = prepare() +q.add_atom('N', neighbors=(1, 2)) +q.add_atom('C', neighbors=(2, 3), heteroatoms=1) +q.add_bond(1, 2, 2) + +p.add_atom('C', _map=2) +p.add_atom('O') +p.add_bond(2, 3, 2) + +# Halogenation of alkanes +# +q, p = prepare() +q.add_atom('C', hybridization=1) +q.add_atom(ListElement(['F', 'Cl', 'Br'])) +q.add_bond(1, 2, 1) + +p.add_atom('C') + +# heterocyclization +# +q, p = prepare() +q.add_atom('N', heteroatoms=0, hybridization=1, neighbors=(2, 3)) +q.add_atom('C', heteroatoms=2) +q.add_atom('N', heteroatoms=0, neighbors=2) +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 2) + +p.add_atom('N') +p.add_atom('C') +p.add_atom('N') +p.add_atom('O') +p.add_bond(1, 2, 1) +p.add_bond(2, 4, 2) + +# Reduction of nitrile +# +q, p = prepare() +q.add_atom('N', neighbors=1) +q.add_atom('C') +q.add_atom('C', hybridization=1) +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 1) + +p.add_atom('N') +p.add_atom('C') +p.add_atom('C') +p.add_bond(1, 2, 3) +p.add_bond(2, 3, 1) + +# SPECIAL CASE +# Reduction of nitrile into methylamine +# +q, p = prepare() +q.add_atom('C', neighbors=1) +q.add_atom('N', neighbors=2) +q.add_atom('C') +q.add_atom('C', hybridization=1) +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 1) +q.add_bond(3, 4, 1) + +p.add_atom('N', _map=2) +p.add_atom('C') +p.add_atom('C') +p.add_bond(2, 3, 3) +p.add_bond(3, 4, 1) + +# methylation of amides +# +q, p = prepare() +q.add_atom('O') +q.add_atom('C') +q.add_atom('N') +q.add_atom('C', neighbors=1) 
+q.add_bond(1, 2, 2) +q.add_bond(2, 3, 1) +q.add_bond(3, 4, 1) + +p.add_atom('O') +p.add_atom('C') +p.add_atom('N') +p.add_bond(1, 2, 2) +p.add_bond(2, 3, 1) + +# hydrocyanation of alkenes +# +q, p = prepare() +q.add_atom('C', hybridization=1) +q.add_atom('C') +q.add_atom('C') +q.add_atom('N') +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 1) +q.add_bond(3, 4, 3) + +p.add_atom('C') +p.add_atom('C') +p.add_bond(1, 2, 2) + +# decarbocylation (alpha atom of nitrile) +# +q, p = prepare() +q.add_atom('N') +q.add_atom('C') +q.add_atom('C', neighbors=2) +q.add_bond(1, 2, 3) +q.add_bond(2, 3, 1) + +p.add_atom('N') +p.add_atom('C') +p.add_atom('C') +p.add_atom('C') +p.add_atom('O') +p.add_atom('O') +p.add_bond(1, 2, 3) +p.add_bond(2, 3, 1) +p.add_bond(3, 4, 1) +p.add_bond(4, 5, 2) +p.add_bond(4, 6, 1) + +# Bichler-Napieralski reaction +# +q, p = prepare() +q.add_atom('C', rings_sizes=(6,)) +q.add_atom('C', rings_sizes=(6,)) +q.add_atom('N', rings_sizes=(6,), neighbors=2) +q.add_atom('C') +q.add_atom('C') +q.add_atom('C') +q.add_atom('O') +q.add_atom('O') +q.add_atom('C') +q.add_atom('O', neighbors=1) +q.add_bond(1, 2, 4) +q.add_bond(2, 3, 1) +q.add_bond(3, 4, 1) +q.add_bond(4, 5, 2) +q.add_bond(5, 6, 1) +q.add_bond(6, 7, 2) +q.add_bond(6, 8, 1) +q.add_bond(5, 9, 4) +q.add_bond(9, 10, 1) +q.add_bond(1, 9, 1) + +p.add_atom('C') +p.add_atom('C') +p.add_atom('N') +p.add_atom('C') +p.add_atom('C') +p.add_atom('C') +p.add_atom('O') +p.add_atom('O') +p.add_atom('C') +p.add_atom('O') +p.add_atom('O') +p.add_bond(1, 2, 4) +p.add_bond(2, 3, 1) +p.add_bond(3, 4, 1) +p.add_bond(4, 5, 2) +p.add_bond(5, 6, 1) +p.add_bond(6, 7, 2) +p.add_bond(6, 8, 1) +p.add_bond(5, 9, 1) +p.add_bond(9, 10, 2) +p.add_bond(9, 11, 1) + +# heterocyclization in Prins reaction +# +q, p = prepare() +q.add_atom('C') +q.add_atom('O') +q.add_atom('C') +q.add_atom(ListElement(['N', 'O']), neighbors=2) +q.add_atom('C') +q.add_atom('C') +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 1) +q.add_bond(3, 4, 1) +q.add_bond(4, 5, 1) 
+q.add_bond(5, 6, 1) +q.add_bond(1, 6, 1) + +p.add_atom('C') +p.add_atom('C', _map=5) +p.add_bond(1, 5, 2) + +# recyclization of tetrahydropyran through an opening the ring and dehydration +# +q, p = prepare() +q.add_atom('C') +q.add_atom('C') +q.add_atom('C') +q.add_atom(ListElement(['N', 'O'])) +q.add_atom('C') +q.add_atom('C') +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 1) +q.add_bond(3, 4, 1) +q.add_bond(4, 5, 1) +q.add_bond(5, 6, 1) +q.add_bond(1, 6, 2) + +p.add_atom('C') +p.add_atom('C') +p.add_atom('C') +p.add_atom('A') +p.add_atom('C') +p.add_atom('C') +p.add_atom('O') +p.add_bond(1, 2, 1) +p.add_bond(1, 7, 1) +p.add_bond(3, 7, 1) +p.add_bond(3, 4, 1) +p.add_bond(4, 5, 1) +p.add_bond(5, 6, 1) +p.add_bond(1, 6, 1) + +# alkenes + h2o/hHal +# +q, p = prepare() +q.add_atom('C', hybridization=1) +q.add_atom('C', hybridization=1) +q.add_atom(ListElement(['O', 'F', 'Cl', 'Br', 'I']), neighbors=1) +q.add_bond(1, 2, 1) +q.add_bond(2, 3, 1) + +p.add_atom('C') +p.add_atom('C') +p.add_bond(1, 2, 2) + +# methylation of dimethylamines +# +q, p = prepare() +q.add_atom('C', neighbors=1) +q.add_atom('N', neighbors=3) +q.add_bond(1, 2, 1) + +p.add_atom('N', _map=2) + +__all__ = ['rules'] diff --git a/SynTool/chem/retron.py b/SynTool/chem/retron.py new file mode 100755 index 0000000000000000000000000000000000000000..3c763663edbc2699c3ea215b3948dde23f5bd721 --- /dev/null +++ b/SynTool/chem/retron.py @@ -0,0 +1,132 @@ +""" +Module containing a class Retron that represents a retron (extend molecule object) in the search tree +""" + +from CGRtools.containers import MoleculeContainer +from CGRtools.exceptions import InvalidAromaticRing + +from SynTool.chem.utils import safe_canonicalization + + +class Retron: + """ + Retron class is used to extend the molecule behavior needed for interaction with a tree in MCTS + """ + + def __init__(self, molecule: MoleculeContainer, canonicalize: bool = True): + """ + It initializes a Retron object with a molecule container as a parameter. 
+ + :param molecule: The `molecule` parameter is of type `MoleculeContainer`. + :type molecule: MoleculeContainer + """ + self._molecule = safe_canonicalization(molecule) if canonicalize else molecule + self._mapping = None + self.prev_retrons = [] + + def __len__(self): + """ + Return the number of atoms in Retron. + """ + return len(self._molecule) + + def __hash__(self): + """ + Returns the hash value of Retron. + """ + return hash(self._molecule) + + def __str__(self): + return str(self._molecule) + + def __eq__(self, other: "Retron"): + """ + The function checks if the current Retron is equal to another Retron of the same class. + + :param other: The "other" parameter is a reference to another object of the same class "Retron". It is used to + compare the current Retron with the other Retron to check if they are equal + :type other: "Retron" + """ + return self._molecule == other._molecule + + def validate_molecule(self): + molecule = self._molecule.copy() + try: + molecule.kekule() + if molecule.check_valence(): + return False + molecule.thiele() + except InvalidAromaticRing: + return False + return True + + @property + def molecule(self) -> MoleculeContainer: + """ + Returns a remapped MoleculeContainer object if self._mapping=True. + """ + if self._mapping: + remapped = self._molecule.copy() + try: + remapped = self._molecule.remap(self._mapping, copy=True) + except ValueError: + pass + return remapped + return self._molecule.copy() + + def __repr__(self): + """ + Returns a SMILES of the retron + """ + return str(self._molecule) + + def is_building_block(self, stock, min_mol_size=6): + """ + The function checks if a Retron is a building block. + + :param min_mol_size: + :param stock: The list of building blocks. Each building block is represented by a smiles. 
+ """ + if len(self._molecule) <= min_mol_size: + return True + else: + return str(self._molecule) in stock + + +def compose_retrons(retrons: list = None, exclude_small=True, min_mol_size: int = 6 + ) -> MoleculeContainer: + """ + The function takes a list of retrons, excludes small retrons if specified, and composes them into a single molecule. + This molecule is used for the prediction of synthesisability of the characterizing the possible success of the path + including the nodes with the given retrons. + + :param retrons: The list of retrons to be composed. + :type retrons: list + :param exclude_small: The parameter that determines whether small retrons should be + excluded from the composition process. If `exclude_small` is set to `True`, only retrons with a length greater than + min_mol_size will be considered for composition. + :param min_mol_size: parameter used with exclude_small + :return: A composed retrons as a MoleculeContainer object. + """ + + if len(retrons) == 1: + return retrons[0].molecule + elif len(retrons) > 1: + if exclude_small: + big_retrons = [ + retron for retron in retrons if len(retron.molecule) > min_mol_size + ] + if big_retrons: + retrons = big_retrons + tmp_mol = retrons[0].molecule.copy() + transition_mapping = {} + for mol in retrons[1:]: + for n, atom in mol.molecule.atoms(): + new_number = tmp_mol.add_atom(atom.atomic_symbol) + transition_mapping[n] = new_number + for atom, neighbor, bond in mol.molecule.bonds(): + tmp_mol.add_bond( + transition_mapping[atom], transition_mapping[neighbor], bond + ) + transition_mapping = {} + return tmp_mol diff --git a/SynTool/chem/utils.py b/SynTool/chem/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..78e641fb90ea2e3d1ab9929f58ec826f50666708 --- /dev/null +++ b/SynTool/chem/utils.py @@ -0,0 +1,227 @@ +from typing import List, Iterable, Tuple, Union + +from CGRtools.containers import MoleculeContainer, ReactionContainer, QueryContainer +from CGRtools.exceptions 
import InvalidAromaticRing
+
+
+def query_to_mol(query: QueryContainer) -> MoleculeContainer:
+    """
+    Converts a QueryContainer object into a MoleculeContainer object.
+
+    :param query: A QueryContainer object representing the query structure.
+    :return: A MoleculeContainer object that replicates the structure of the query.
+    """
+    new_mol = MoleculeContainer()
+    for n, atom in query.atoms():
+        # keep original atom numbers, charges and radical flags; query-only attributes are dropped
+        new_mol.add_atom(atom.atomic_symbol, n, charge=atom.charge, is_radical=atom.is_radical)
+    for i, j, bond in query.bonds():
+        new_mol.add_bond(i, j, int(bond))
+    return new_mol
+
+
+def reaction_query_to_reaction(rule: ReactionContainer) -> ReactionContainer:
+    """
+    Converts a ReactionContainer object with query structures into a ReactionContainer with molecular structures.
+
+    :param rule: A ReactionContainer object where reactants and products are QueryContainer objects.
+    :return: A new ReactionContainer object where reactants and products are MoleculeContainer objects.
+    """
+    reactants = [query_to_mol(q) for q in rule.reactants]
+    products = [query_to_mol(q) for q in rule.products]
+    reagents = [query_to_mol(q) for q in rule.reagents]  # Assuming reagents are also part of the rule
+    reaction = ReactionContainer(reactants, products, reagents, rule.meta)
+    reaction.name = rule.name
+    return reaction
+
+
+def unite_molecules(molecules: Iterable[MoleculeContainer]) -> MoleculeContainer:
+    """
+    Unites a list of MoleculeContainer objects into a single MoleculeContainer.
+
+    This function takes multiple molecules and combines them into one larger molecule.
+    The first molecule in the list is taken as the base, and subsequent molecules are united with it sequentially.
+
+    :param molecules: A list of MoleculeContainer objects to be united.
+    :return: A single MoleculeContainer object representing the union of all input molecules.
+    """
+    new_mol = MoleculeContainer()
+    for mol in molecules:
+        new_mol = new_mol.union(mol)
+    return new_mol
+
+
+def safe_canonicalization(molecule: MoleculeContainer):
+    """
+    Attempts to canonicalize a molecule, handling any exceptions.
+
+    This function tries to canonicalize the given molecule.
+    If the canonicalization process fails due to an InvalidAromaticRing exception,
+    it safely returns the original molecule.
+
+    :param molecule: The given molecule to be canonicalized.
+    :return: The canonicalized molecule if successful, otherwise the original molecule.
+    """
+    # NOTE(review): reorders the private atom dict by atom number, mutating the
+    # input molecule's internal state before copying — presumably to make
+    # canonical output deterministic; confirm this side effect is intended.
+    molecule._atoms = dict(sorted(molecule._atoms.items()))
+
+    tmp = molecule.copy()
+    try:
+        tmp.canonicalize()
+        return tmp
+    except InvalidAromaticRing:
+        return molecule
+
+
+def split_molecules(molecules: Iterable, number_of_atoms: int) -> Tuple[List, List]:
+    """
+    Splits molecules according to the number of heavy atoms.
+
+    :param molecules: Iterable of molecules.
+    :param number_of_atoms: Threshold for splitting molecules (strictly greater than counts as "big").
+    :return: Tuple of lists containing "big" molecules and "small" molecules.
+    """
+    big_molecules, small_molecules = [], []
+    for molecule in molecules:
+        if len(molecule) > number_of_atoms:
+            big_molecules.append(molecule)
+        else:
+            small_molecules.append(molecule)
+
+    return big_molecules, small_molecules
+
+
+def remove_small_molecules(
+        reaction: ReactionContainer,
+        number_of_atoms: int = 6,
+        small_molecules_to_meta: bool = True
+) -> Union[ReactionContainer, None]:
+    """
+    Processes a reaction by removing small molecules.
+
+    :param reaction: ReactionContainer object.
+    :param number_of_atoms: Molecules with the number of atoms equal to or below this will be removed.
+    :param small_molecules_to_meta: If True, deleted molecules are saved to meta.
+    :return: Processed ReactionContainer without small molecules. 
+ """ + new_reactants, small_reactants = split_molecules(reaction.reactants, number_of_atoms) + new_products, small_products = split_molecules(reaction.products, number_of_atoms) + + if sum(len(mol) for mol in new_reactants) == 0 or sum(len(mol) for mol in new_reactants) == 0: + return None + + new_reaction = ReactionContainer(new_reactants, new_products, reaction.reagents, reaction.meta) + new_reaction.name = reaction.name + + if small_molecules_to_meta: + united_small_reactants = unite_molecules(small_reactants) + new_reaction.meta["small_reactants"] = str(united_small_reactants) + + united_small_products = unite_molecules(small_products) + new_reaction.meta["small_products"] = str(united_small_products) + + return new_reaction + + +def rebalance_reaction(reaction: ReactionContainer) -> ReactionContainer: + """ + Rebalances the reaction by assembling CGR and then decomposing it. Works for all reactions for which the correct + CGR can be assembled + + :param reaction: a reaction object + :return: a rebalanced reaction + """ + tmp_reaction = ReactionContainer(reaction.reactants, reaction.products) + cgr = ~tmp_reaction + reactants, products = ~cgr + rebalanced_reaction = ReactionContainer(reactants.split(), products.split(), reaction.reagents, reaction.meta) + rebalanced_reaction.name = reaction.name + return rebalanced_reaction + + +def reverse_reaction(reaction: ReactionContainer) -> ReactionContainer: + """ + Reverses given reaction + + :param reaction: a reaction object + :return: the reversed reaction + """ + reversed_reaction = ReactionContainer(reaction.products, reaction.reactants, reaction.reagents, reaction.meta) + reversed_reaction.name = reaction.name + + return reversed_reaction + + +def remove_reagents( + reaction: ReactionContainer, + keep_reagents: bool = True, + reagents_max_size: int = 7 +) -> Union[ReactionContainer, None]: + """ + Removes reagents (not changed molecules or molecules not involved in the reaction) from reactants and products + + 
:param reaction: a reaction object + :param keep_reagents: if True, the reagents are written to ReactionContainer + :param reagents_max_size: max size of molecules that are called reagents, bigger are deleted + :return: cleaned reaction + """ + not_changed_molecules = set(reaction.reactants).intersection(reaction.products) + + cgr = ~reaction + center_atoms = set(cgr.center_atoms) + + new_reactants = [] + new_products = [] + new_reagents = [] + + for molecule in reaction.reactants: + if center_atoms.isdisjoint(molecule) or molecule in not_changed_molecules: + new_reagents.append(molecule) + else: + new_reactants.append(molecule) + + for molecule in reaction.products: + if center_atoms.isdisjoint(molecule) or molecule in not_changed_molecules: + new_reagents.append(molecule) + else: + new_products.append(molecule) + + if sum(len(mol) for mol in new_reactants) == 0 or sum(len(mol) for mol in new_reactants) == 0: + return None + + if keep_reagents: + new_reagents = {molecule for molecule in new_reagents if len(molecule) <= reagents_max_size} + else: + new_reagents = [] + + new_reaction = ReactionContainer(new_reactants, new_products, new_reagents, reaction.meta) + new_reaction.name = reaction.name + + return new_reaction + + +def to_reaction_smiles_record(reaction): + if isinstance(reaction, str): + return reaction + + reaction_record = [format(reaction, "m")] + sorted_meta = sorted(reaction.meta.items(), key=lambda x: x[0]) + for _, meta_info in sorted_meta: + # meta_info = str(meta_info) + meta_info = '' # TODO decide what to do with meta + meta_info = ";".join(meta_info.split("\n")) + reaction_record.append(str(meta_info)) + # return "\t".join(reaction_record) + "\n" + return "".join(reaction_record) + + +def cgr_from_rule(rule: ReactionContainer): + reaction_rule = reaction_query_to_reaction(rule) + cgr_rule = ~reaction_rule + return cgr_rule + + +def hash_from_rule(reaction_rule: ReactionContainer): + reactants_hash = tuple(sorted(hash(r) for r in 
reaction_rule.reactants)) + reagents_hash = tuple(sorted(hash(r) for r in reaction_rule.reagents)) + products_hash = tuple(sorted(hash(r) for r in reaction_rule.products)) + return hash((reactants_hash, reagents_hash, products_hash)) diff --git a/SynTool/interfaces/__init__.py b/SynTool/interfaces/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SynTool/interfaces/__pycache__/__init__.cpython-310.pyc b/SynTool/interfaces/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1690b536c1c159a54b0386e9305953efc009d5a Binary files /dev/null and b/SynTool/interfaces/__pycache__/__init__.cpython-310.pyc differ diff --git a/SynTool/interfaces/__pycache__/visualisation.cpython-310.pyc b/SynTool/interfaces/__pycache__/visualisation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7011e8fef7284cfb2f84aedee236c4f425d6fd0c Binary files /dev/null and b/SynTool/interfaces/__pycache__/visualisation.cpython-310.pyc differ diff --git a/SynTool/interfaces/cli.py b/SynTool/interfaces/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..6cdde6f6f99c3e60a7d002f9a87d202e0e37f304 --- /dev/null +++ b/SynTool/interfaces/cli.py @@ -0,0 +1,530 @@ +""" +Module containing commands line scripts for training and planning mode +""" + +import os +import shutil +import yaml +import warnings +from pathlib import Path + +import click +import gdown + +from SynTool.chem.data.cleaning import reactions_cleaner +from SynTool.chem.data.filtering import filter_reactions, ReactionCheckConfig +from SynTool.utils.loading import standardize_building_blocks +from SynTool.chem.reaction_rules.extraction import extract_rules_from_reactions +from SynTool.mcts.search import tree_search +from SynTool.ml.training.reinforcement import run_reinforcement_tuning +from SynTool.ml.training.supervised import create_policy_dataset, 
run_policy_training +from SynTool.utils.config import ReinforcementConfig, TreeConfig, PolicyNetworkConfig, ValueNetworkConfig +from SynTool.utils.config import ReactionStandardizationConfig, RuleExtractionConfig +from SynTool.chem.data.mapping import remove_reagents_and_map_from_file + +warnings.filterwarnings("ignore") + + +@click.group(name="syntool") +def syntool(): + pass + + +@syntool.command(name="download_planning_data") +@click.option( + "--root_dir", + required=True, + type=click.Path(exists=True), + help="Path to the reaction database file that will be mapped.", +) +def download_planning_data_cli(root_dir='.'): + """ + Downloads data for retrosythesis planning + """ + remote_id = "1ygq9BvQgH2Tq_rL72BvSOdASSSbPFTsL" + output = os.path.join(root_dir, "syntool_planning_data.zip") + # + gdown.download(output=output, id=remote_id, quiet=False) + shutil.unpack_archive(output, root_dir) + # + os.remove(output) + + +@syntool.command(name='download_training_data') +@click.option( + "--root_dir", + required=True, + type=click.Path(exists=True), + help="Path to the reaction database file that will be mapped.", +) +def download_training_data_cli(root_dir='.'): + """ + Downloads data for retrosythetic models training + """ + remote_id = "1ckhO1l6xud0_bnC0rCDMkIlKRUMG_xs8" + output = os.path.join(root_dir, "syntool_training_data.zip") + # + gdown.download(output=output, id=remote_id, quiet=False) + shutil.unpack_archive(output, root_dir) + # + os.remove(output) + + +@syntool.command(name="building_blocks") +@click.option( + "--input", + "input_file", + required=True, + type=click.Path(exists=True), + help="Path to the reaction database file that will be mapped.", +) +@click.option( + "--output", + "output_file", + required=True, + type=click.Path(), + help="File where the results will be stored.", +) +def building_blocks_cli(input_file, output_file): + """ + Standardizes building blocks + """ + + standardize_building_blocks(input_file=input_file, 
output_file=output_file) + + +@syntool.command(name="reaction_mapping") +@click.option( + "--config", + "config_path", + required=True, + type=click.Path(exists=True), + help="Path to the configuration file. This file contains settings for mapping and standardizing reactions.", +) +@click.option( + "--input", + "input_file", + required=True, + type=click.Path(exists=True), + help="Path to the reaction database file that will be mapped.", +) +@click.option( + "--output", + "output_file", + default=Path("reaction_data_standardized.smi"), + type=click.Path(), + help="File where the results will be stored.", +) +def reaction_mapping_cli(config_path, input_file, output_file): + """ + Reaction data mapping + """ + + stand_config = ReactionStandardizationConfig.from_yaml(config_path) + remove_reagents_and_map_from_file(input_file=input_file, output_file=output_file, keep_reagent=stand_config.keep_reagents) + + +@syntool.command(name="reaction_standardizing") +@click.option( + "--config", + "config_path", + required=True, + type=click.Path(exists=True), + help="Path to the configuration file. This file contains settings for mapping and standardizing reactions.", +) +@click.option( + "--input", + "input_file", + required=True, + type=click.Path(exists=True), + help="Path to the reaction database file that will be mapped.", +) +@click.option( + "--output", + "output_file", + type=click.Path(), + help="File where the results will be stored.", +) +@click.option( + "--num_cpus", + default=8, + type=int, + help="Number of CPUs to use for processing. 
Defaults to 1.", +) +def reaction_standardizing_cli(config_path, input_file, output_file, num_cpus): + """ + Standardizes reactions and remove duplicates + """ + + stand_config = ReactionStandardizationConfig.from_yaml(config_path) + reactions_cleaner(config=stand_config, + input_file=input_file, + output_file=output_file, + num_cpus=num_cpus) + + +@syntool.command(name="reaction_filtering") +@click.option( + "--config", + "config_path", + required=True, + type=click.Path(exists=True), + help="Path to the configuration file. This file contains settings for filtering reactions.", +) +@click.option( + "--input", + "input_file", + required=True, + type=click.Path(exists=True), + help="Path to the reaction database file that will be mapped.", +) +@click.option( + "--output", + "output_file", + default=Path("./"), + type=click.Path(), + help="File where the results will be stored.", +) +@click.option( + "--append_results", + is_flag=True, + default=False, + help="If set, results will be appended to existing files. By default, new files are created.", +) +@click.option( + "--batch_size", + default=100, + type=int, + help="Size of the batch for processing reactions. Defaults to 10.", +) +@click.option( + "--num_cpus", + default=8, + type=int, + help="Number of CPUs to use for processing. Defaults to 1.", +) +def reaction_filtering_cli(config_path, + input_file, + output_file, + append_results, + batch_size, + num_cpus): + """ + Filters erroneous reactions + """ + reaction_check_config = ReactionCheckConfig().from_yaml(config_path) + filter_reactions( + config=reaction_check_config, + reaction_database_path=input_file, + result_reactions_file_name=output_file, + append_results=append_results, + num_cpus=num_cpus, + batch_size=batch_size, + ) + + +@syntool.command(name="rule_extracting") +@click.option( + "--config", + "config_path", + required=True, + type=click.Path(exists=True), + help="Path to the configuration file. 
This file contains settings for reaction rules extraction.", +) +@click.option( + "--input", + "input_file", + required=True, + type=click.Path(exists=True), + help="Path to the reaction database file that will be mapped.", +) +@click.option( + "--output", + "output_file", + required=True, + type=click.Path(), + help="File where the results will be stored.", +) +@click.option( + "--batch_size", + default=100, + type=int, + help="Size of the batch for processing reactions. Defaults to 10.", +) +@click.option( + "--num_cpus", + default=4, + type=int, + help="Number of CPUs to use for processing. Defaults to 4.", +) +def rule_extracting_cli( + config_path, + input_file, + output_file, + num_cpus, + batch_size, +): + """ + Extracts reaction rules + """ + + reaction_rule_config = RuleExtractionConfig.from_yaml(config_path) + extract_rules_from_reactions(config=reaction_rule_config, + reaction_file=input_file, + rules_file_name=output_file, + num_cpus=num_cpus, + batch_size=batch_size) + + +@syntool.command(name="supervised_ranking_policy_training") +@click.option( + "--config", + "config_path", + required=True, + type=click.Path(exists=True), + help="Path to the configuration file. This file contains settings for policy training.", +) +@click.option( + "--reaction_data", + required=True, + type=click.Path(exists=True), + help="Path to the reaction database file that will be mapped.", +) +@click.option( + "--reaction_rules", + required=True, + type=click.Path(exists=True), + help="Path to the reaction database file that will be mapped.", +) +@click.option( + "--results_dir", + default=Path("."), + type=click.Path(), + help="Root directory where the results will be stored.", +) +@click.option( + "--num_cpus", + default=4, + type=int, + help="Number of CPUs to use for processing. 
Defaults to 4.", +) +def supervised_ranking_policy_training_cli(config_path, reaction_data, reaction_rules, results_dir, num_cpus): + """ + Trains ranking policy network + """ + + policy_config = PolicyNetworkConfig.from_yaml(config_path) + + policy_dataset_file = os.path.join(results_dir, 'policy_dataset.dt') + + datamodule = create_policy_dataset(reaction_rules_path=reaction_rules, + molecules_or_reactions_path=reaction_data, + output_path=policy_dataset_file, + dataset_type='ranking', + batch_size=policy_config.batch_size, + num_cpus=num_cpus) + + run_policy_training(datamodule, config=policy_config, results_path=results_dir) + + +@syntool.command(name="supervised_filtering_policy_training") +@click.option( + "--config", + "config_path", + required=True, + type=click.Path(exists=True), + help="Path to the configuration file. This file contains settings for policy training.", +) +@click.option( + "--molecules_file", + required=True, + type=click.Path(exists=True), + help="Path to the molecules database file that will be mapped.", +) +@click.option( + "--reaction_rules", + required=True, + type=click.Path(exists=True), + help="Path to the reaction database file that will be mapped.", +) +@click.option( + "--results_dir", + default=Path("."), + type=click.Path(), + help="Root directory where the results will be stored.", +) +@click.option( + "--num_cpus", + default=8, + type=int, + help="Number of CPUs to use for processing. 
Defaults to 1.", +) +def supervised_filtering_policy_training_cli(config_path, molecules_file, reaction_rules, results_dir, num_cpus): + """ + Trains filtering policy network + """ + + policy_config = PolicyNetworkConfig.from_yaml(config_path) + + policy_dataset_file = os.path.join(results_dir, 'policy_dataset.ckpt') + datamodule = create_policy_dataset(reaction_rules_path=reaction_rules, + molecules_or_reactions_path=molecules_file, + output_path=policy_dataset_file, + dataset_type='filtering', + batch_size=policy_config.batch_size, + num_cpus=num_cpus) + + run_policy_training(datamodule, config=policy_config, results_path=results_dir) + + +@syntool.command(name="reinforcement_value_network_training") +@click.option( + "--config", + required=True, + type=click.Path(exists=True), + help="Path to the configuration file. This file contains settings for policy training.", +) +@click.option( + "--targets", + required=True, + type=click.Path(exists=True), + help="Path to the configuration file. This file contains settings for policy training.", +) +@click.option( + "--reaction_rules", + required=True, + type=click.Path(exists=True), + help="Path to the configuration file. This file contains settings for policy training.", +) +@click.option( + "--building_blocks", + required=True, + type=click.Path(exists=True), + help="Path to the configuration file. This file contains settings for policy training.", +) +@click.option( + "--policy_network", + required=True, + type=click.Path(exists=True), + help="Path to the configuration file. This file contains settings for policy training.", +) +@click.option( + "--value_network", + default=None, + type=click.Path(exists=True), + help="Path to the configuration file. This file contains settings for policy training.", +) +@click.option( + "--results_dir", + default='.', + type=click.Path(exists=False), + help="Path to the configuration file. 
This file contains settings for policy training.", +) +def reinforcement_value_network_training_cli(config, + targets, + reaction_rules, + building_blocks, + policy_network, + value_network, + results_dir): + """ + Trains value network with reinforcement learning + """ + + with open(config, "r") as file: + config = yaml.safe_load(file) + + policy_config = PolicyNetworkConfig.from_dict(config['node_expansion']) + policy_config.weights_path = policy_network + + value_config = ValueNetworkConfig.from_dict(config['value_network']) + if value_network is None: + value_config.weights_path = os.path.join(results_dir, 'weights', 'value_network.ckpt') + + tree_config = TreeConfig.from_dict(config['tree']) + reinforce_config = ReinforcementConfig.from_dict(config['reinforcement']) + + run_reinforcement_tuning(targets_path=targets, + tree_config=tree_config, + policy_config=policy_config, + value_config=value_config, + reinforce_config=reinforce_config, + reaction_rules_path=reaction_rules, + building_blocks_path=building_blocks, + results_root=results_dir) + + +@syntool.command(name="planning") +@click.option( + "--config", + "config_path", + required=True, + type=click.Path(exists=True), + help="Path to the configuration file. This file contains settings for policy training.", +) +@click.option( + "--targets", + required=True, + type=click.Path(exists=True), + help="Path to the configuration file. This file contains settings for policy training.", +) +@click.option( + "--reaction_rules", + required=True, + type=click.Path(exists=True), + help="Path to the configuration file. This file contains settings for policy training.", +) +@click.option( + "--building_blocks", + required=True, + type=click.Path(exists=True), + help="Path to the configuration file. This file contains settings for policy training.", +) +@click.option( + "--policy_network", + required=True, + type=click.Path(exists=True), + help="Path to the configuration file. 
This file contains settings for policy training.", +) +@click.option( + "--value_network", + default=None, + type=click.Path(exists=True), + help="Path to the configuration file. This file contains settings for policy training.", +) +@click.option( + "--results_dir", + default='.', + type=click.Path(exists=False), + help="Path to the configuration file. This file contains settings for policy training.", +) +def planning_cli(config_path, + targets, + reaction_rules, + building_blocks, + policy_network, + value_network, + results_dir): + """ + Runs retrosynthesis planning + """ + + with open(config_path, "r") as file: + config = yaml.safe_load(file) + + tree_config = TreeConfig.from_dict({**config['tree'], **config['node_evaluation']}) + policy_config = PolicyNetworkConfig.from_dict({**config['node_expansion'], **{'weights_path': policy_network}}) + + tree_search(targets_path=targets, + tree_config=tree_config, + policy_config=policy_config, + reaction_rules_path=reaction_rules, + building_blocks_path=building_blocks, + value_weights_path=value_network, + results_root=results_dir) + + +if __name__ == '__main__': + syntool() + + diff --git a/SynTool/interfaces/cli.py.bk b/SynTool/interfaces/cli.py.bk new file mode 100644 index 0000000000000000000000000000000000000000..550f128644c43402dabe3db437e847fb049c24f9 --- /dev/null +++ b/SynTool/interfaces/cli.py.bk @@ -0,0 +1,241 @@ +""" +Module containing commands line scripts for training and planning mode +""" + +import warnings +import os +import shutil +from pathlib import Path +import click +import gdown +from datetime import datetime + +from Syntool.chem.reaction_rules.extraction import extract_rules_from_reactions +from Syntool.chem.data.cleaning import reactions_cleaner +from Syntool.chem.data.mapping import remove_reagents_and_map_from_file +from Syntool.chem.loading import standardize_building_blocks +from Syntool.ml.training import create_policy_dataset, run_policy_training +from Syntool.ml.training.reinforcement 
import run_self_tuning +from Syntool.ml.networks.policy import PolicyNetworkConfig +from Syntool.utils.config import read_planning_config, read_training_config, TreeConfig +from Syntool.mcts.search import tree_search + +from Syntool.chem.data.filtering import ( + filter_reactions, + ReactionCheckConfig, + CCRingBreakingConfig, + WrongCHBreakingConfig, + CCsp3BreakingConfig, + DynamicBondsConfig, + MultiCenterConfig, + NoReactionConfig, + SmallMoleculesConfig, +) + +warnings.filterwarnings("ignore") +main = click.Group() + + +@main.command(name='planning_data') +def planning_data_cli(): + """ + Downloads a file from Google Drive using its remote ID, saves it as a zip file, extracts the contents, + and then deletes the zip file + """ + remote_id = '1c5YJDT-rP1ZvFA-ELmPNTUj0b8an4yFf' + output = 'synto_planning_data.zip' + # + gdown.download(output=output, id=remote_id, quiet=True) + shutil.unpack_archive(output, './') + # + os.remove(output) + + +@main.command(name='training_data') +def training_data_cli(): + """ + Downloads a file from Google Drive using its remote ID, saves it as a zip file, extracts the contents, + and then deletes the zip file + """ + remote_id = '1r4I7OskGvzg-zxYNJ7WVYpVR2HSYW10N' + output = 'synto_training_data.zip' + # + gdown.download(output=output, id=remote_id, quiet=True) + shutil.unpack_archive(output, './') + # + os.remove(output) + + +@main.command(name='syntool_planning') +@click.option("--config", + "config_path", + required=True, + help="Path to the config YAML molecules_path.", + type=click.Path(exists=True, path_type=Path), + ) +def syntool_planning_cli(config_path): + """ + Launches tree search for the given target molecules and stores search statistics and found retrosynthetic paths + + :param config_path: The path to the configuration file that contains the settings and parameters for the tree search + """ + config = read_planning_config(config_path) + config['Tree']['silent'] = True + + # standardize building blocks + if 
config['InputData']['standardize_building_blocks']: + print('STANDARDIZE BUILDING BLOCKS ...') + standardize_building_blocks(config['InputData']['building_blocks_path'], + config['InputData']['building_blocks_path']) + # run planning + print('\nRUN PLANNING ...') + tree_config = TreeConfig.from_dict(config['Tree']) + tree_search(targets=config['General']['targets_path'], + tree_config=tree_config, + reaction_rules_path=config['InputData']['reaction_rules_path'], + building_blocks_path=config['InputData']['building_blocks_path'], + policy_weights_path=config['PolicyNetwork']['weights_path'], + value_weights_paths=config['ValueNetwork']['weights_path'], + results_root=config['General']['results_root']) + + +@main.command(name='syntool_training') +@click.option( + "--config", + "config_path", + required=True, + help="Path to the config YAML file.", + type=click.Path(exists=True, path_type=Path) + ) +def syntool_training_cli(config_path): + + # read training config + print('READ CONFIG ...') + config = read_training_config(config_path) + print('Config is read') + + reaction_data_file = config['InputData']['reaction_data_path'] + + # reaction data mapping + startTime0 = datetime.now() + data_output_folder = os.path.join(config['General']['results_root'], 'reaction_data') + Path(data_output_folder).mkdir(parents=True, exist_ok=True) + mapped_data_file = os.path.join(data_output_folder, 'reaction_data_mapped.smi') + if config['DataCleaning']['map_reactions']: + print('\nMAP REACTION DATA ...') + + remove_reagents_and_map_from_file(input_file=config['InputData']['reaction_data_path'], + output_file=mapped_data_file) + + reaction_data_file = mapped_data_file + print("remove_reagents_and_map_from_file:", datetime.now() - startTime0) + + # reaction data cleaning + startTime0 = datetime.now() + cleaned_data_file = os.path.join(data_output_folder, 'reaction_data_cleaned.rdf') + if config['DataCleaning']['clean_reactions']: + print('\nCLEAN REACTION DATA ...') + + 
reactions_cleaner(input_file=reaction_data_file, + output_file=cleaned_data_file, + num_cpus=config['General']['num_cpus']) + + reaction_data_file = cleaned_data_file + print("reactions_cleaner:", datetime.now() - startTime0) + + # reactions data filtering + startTime0 = datetime.now() + if config['DataCleaning']['filter_reactions']: + print('\nFILTER REACTION DATA ...') + # + filtration_config = ReactionCheckConfig( + remove_small_molecules=False, + small_molecules_config=SmallMoleculesConfig(limit=6), + dynamic_bonds_config=DynamicBondsConfig(min_bonds_number=1, max_bonds_number=6), + no_reaction_config=NoReactionConfig(), + multi_center_config=MultiCenterConfig(), + wrong_ch_breaking_config=WrongCHBreakingConfig(), + cc_sp3_breaking_config=CCsp3BreakingConfig(), + cc_ring_breaking_config=CCRingBreakingConfig() + ) + + filtered_data_file = os.path.join(data_output_folder, 'reaction_data_filtered.rdf') + filter_reactions(config=filtration_config, + reaction_database_path=reaction_data_file, + result_directory_path=data_output_folder, + result_reactions_file_name='reaction_data_filtered', + num_cpus=config['General']['num_cpus'], + batch_size=100) + + reaction_data_file = filtered_data_file + print("filter_reactions:", datetime.now() - startTime0) + + # standardize building blocks + startTime0 = datetime.now() + if config['DataCleaning']['standardize_building_blocks']: + print('\nSTANDARDIZE BUILDING BLOCKS ...') + + standardize_building_blocks(config['InputData']['building_blocks_path'], + config['InputData']['building_blocks_path']) + print("standardize_building_blocks:", datetime.now() - startTime0) + + # reaction rules extraction + startTime0 = datetime.now() + print('\nEXTRACT REACTION RULES ...') + + rules_output_folder = os.path.join(config['General']['results_root'], 'reaction_rules') + Path(rules_output_folder).mkdir(parents=True, exist_ok=True) + reaction_rules_path = os.path.join(rules_output_folder, 'reaction_rules_filtered.pickle') + 
config['InputData']['reaction_rules_path'] = reaction_rules_path + + extract_rules_from_reactions(config=config, + reaction_file=reaction_data_file, + results_root=rules_output_folder, + num_cpus=config['General']['num_cpus']) + print("extract_rules_from_reactions:", datetime.now() - startTime0) + + # create policy network dataset + startTime0 = datetime.now() + print('\nCREATE POLICY NETWORK DATASET ...') + policy_output_folder = os.path.join(config['General']['results_root'], 'policy_network') + Path(policy_output_folder).mkdir(parents=True, exist_ok=True) + policy_data_file = os.path.join(policy_output_folder, 'policy_dataset.pt') + + if config['PolicyNetwork']['policy_type'] == 'ranking': + molecules_or_reactions_path = reaction_data_file + elif config['PolicyNetwork']['policy_type'] == 'filtering': + molecules_or_reactions_path = config['InputData']['policy_data_path'] + else: + raise ValueError( + "Invalid policy_type. Allowed values are 'ranking', 'filtering'." + ) + + datamodule = create_policy_dataset(reaction_rules_path=reaction_rules_path, + molecules_or_reactions_path=molecules_or_reactions_path, + output_path=policy_data_file, + dataset_type=config['PolicyNetwork']['policy_type'], + batch_size=config['PolicyNetwork']['batch_size'], + num_cpus=config['General']['num_cpus']) + print("datamodule:", datetime.now() - startTime0) + + # train policy network + startTime0 = datetime.now() + print('\nTRAIN POLICY NETWORK ...') + policy_config = PolicyNetworkConfig.from_dict(config['PolicyNetwork']) + run_policy_training(datamodule, config=policy_config, results_path=policy_output_folder) + config['PolicyNetwork']['weights_path'] = os.path.join(policy_output_folder, 'weights', 'policy_network.ckpt') + print("run_policy_training:", datetime.now() - startTime0) + + # self-tuning value network training + startTime0 = datetime.now() + print('\nTRAIN VALUE NETWORK ...') + value_output_folder = os.path.join(config['General']['results_root'], 'value_network') + 
Path(value_output_folder).mkdir(parents=True, exist_ok=True) + config['ValueNetwork']['weights_path'] = os.path.join(value_output_folder, 'weights', 'value_network.ckpt') + + run_self_tuning(config, results_root=value_output_folder) + print("run_self_tuning:", datetime.now() - startTime0) + + +if __name__ == '__main__': + main() diff --git a/SynTool/interfaces/visualisation.py b/SynTool/interfaces/visualisation.py new file mode 100755 index 0000000000000000000000000000000000000000..e82ecf851c16afab365cb7444f3e1ae32ad47797 --- /dev/null +++ b/SynTool/interfaces/visualisation.py @@ -0,0 +1,346 @@ +""" +Module containing functions for analysis and visualization of the built search tree +""" + +from itertools import count, islice + +from CGRtools.containers import MoleculeContainer + +from SynTool import Tree +from SynTool.utils import path_type + + +def get_child_nodes(tree, molecule, graph): + nodes = [] + try: + graph[molecule] + except KeyError: + return [] + for retron in graph[molecule]: + temp_obj = { + "smiles": str(retron), + "type": "mol", + "in_stock": str(retron) in tree.building_blocks, + } + node = get_child_nodes(tree, retron, graph) + if node: + temp_obj["children"] = [node] + nodes.append(temp_obj) + return {"type": "reaction", "children": nodes} + + +def extract_routes(tree, extended=False): + """ + The function takes the target and the dictionary of + successors and predecessors and returns a list of dictionaries that contain the target + and the list of successors + :return: A list of dictionaries. Each dictionary contains a target, a list of children, and a + boolean indicating whether the target is in building_blocks. 
+ """ + target = tree.nodes[1].retrons_to_expand[0].molecule + target_in_stock = tree.nodes[1].curr_retron.is_building_block(tree.building_blocks) + # Append encoded routes to list + paths_block = [] + winning_nodes = [] + if extended: + # Gather paths + for i, node in tree.nodes.items(): + if node.is_solved(): + winning_nodes.append(i) + else: + winning_nodes = tree.winning_nodes + if winning_nodes: + for winning_node in winning_nodes: + # Create graph for route + nodes = tree.path_to_node(winning_node) + graph, pred = {}, {} + for before, after in zip(nodes, nodes[1:]): + before = before.curr_retron.molecule + graph[before] = after = [x.molecule for x in after.new_retrons] + for x in after: + pred[x] = before + + paths_block.append({"type": "mol", "smiles": str(target), + "in_stock": target_in_stock, + "children": [get_child_nodes(tree, target, graph)]}) + else: + paths_block = [{"type": "mol", "smiles": str(target), "in_stock": target_in_stock, "children": []}] + return paths_block + + +def path_graph(tree, node: int) -> str: + """ + Visualizes reaction path + + :param node: node id + :type node: int + :return: The SVG string. 
+ """ + nodes = tree.path_to_node(node) + # Set up node_id types for different box colors + for node in nodes: + for retron in node.new_retrons: + retron._molecule.meta["status"] = "instock" if retron.is_building_block( + tree.building_blocks) else "mulecule" + nodes[0].curr_retron._molecule.meta["status"] = "target" + # Box colors + box_colors = {"target": "#98EEFF", # 152, 238, 255 + "mulecule": "#F0AB90", # 240, 171, 144 + "instock": "#9BFAB3", # 155, 250, 179 + } + + # first column is target + # second column are first new retrons_to_expand + columns = [[nodes[0].curr_retron.molecule], [x.molecule for x in nodes[1].new_retrons], ] + pred = {x: 0 for x in range(1, len(columns[1]) + 1)} + cx = [n for n, x in enumerate(nodes[1].new_retrons, 1) if not x.is_building_block(tree.building_blocks)] + size = len(cx) + nodes = iter(nodes[2:]) + cy = count(len(columns[1]) + 1) + while size: + layer = [] + for s in islice(nodes, size): + n = cx.pop(0) + for x in s.new_retrons: + layer.append(x) + m = next(cy) + if not x.is_building_block(tree.building_blocks): + cx.append(m) + pred[m] = n + size = len(cx) + columns.append([x.molecule for x in layer]) + + columns = [columns[::-1] for columns in columns[::-1]] # Reverse array to make retrosynthetic graph + pred = tuple( # Change dict to tuple to make multiple retrons_to_expand available + (abs(source - len(pred)), abs(target - len(pred))) for target, source in pred.items()) + + # now we have columns for visualizing + # lets start recalculate XY + x_shift = 0.0 + c_max_x = 0.0 + c_max_y = 0.0 + render = [] + cx = count() + cy = count() + arrow_points = {} + for ms in columns: + heights = [] + for m in ms: + m.clean2d() + # X-shift for target + min_x = min(x for x, y in m._plane.values()) - x_shift + min_y = min(y for x, y in m._plane.values()) + m._plane = {n: (x - min_x, y - min_y) for n, (x, y) in m._plane.items()} + max_x = max(x for x, y in m._plane.values()) + if max_x > c_max_x: + c_max_x = max_x + arrow_points[next(cx)] 
= [x_shift, max_x] + heights.append(max(y for x, y in m._plane.values())) + + x_shift = c_max_x + 5.0 # between columns gap + # calculate Y-shift + y_shift = sum(heights) + 3.0 * (len(heights) - 1) + if y_shift > c_max_y: + c_max_y = y_shift + y_shift /= 2.0 + for m, h in zip(ms, heights): + m._plane = {n: (x, y - y_shift) for n, (x, y) in m._plane.items()} + + # Calculate coordinates for boxes + max_x = max(x for x, y in m._plane.values()) + 0.9 # Max x + min_x = min(x for x, y in m._plane.values()) - 0.6 # Min x + max_y = -(max(y for x, y in m._plane.values()) + 0.45) # Max y + min_y = -(min(y for x, y in m._plane.values()) - 0.45) # Min y + x_delta = abs(max_x - min_x) + y_delta = abs(max_y - min_y) + box = ( + f'') + arrow_points[next(cy)].append(y_shift - h / 2.0) + y_shift -= h + 3.0 + depicted_molecule = list(m.depict(embedding=True))[:3] + depicted_molecule.append(box) + render.append(depicted_molecule) + + # Calculate mid-X coordinate to draw square arrows + graph = {} + for s, p in pred: + try: + graph[s].append(p) + except KeyError: + graph[s] = [p] + for s, ps in graph.items(): + mid_x = float("-inf") + for p in ps: + s_min_x, s_max, s_y = arrow_points[s][:3] # s + p_min_x, p_max, p_y = arrow_points[p][:3] # p + p_max += 1 + mid = p_max + (s_min_x - p_max) / 3 + if mid > mid_x: + mid_x = mid + for p in ps: + arrow_points[p].append(mid_x) + + config = MoleculeContainer._render_config + font_size = config["font_size"] + font125 = 1.25 * font_size + width = c_max_x + 4.0 * font_size # 3.0 by default + height = c_max_y + 3.5 * font_size # 2.5 by default + box_y = height / 2.0 + svg = [f'', + ' \n \n \n \n ', ] + + for s, p in pred: + """ + (x1, y1) = (p_max, p_y) + (x2, y2) = (s_min_x, s_y) + polyline: (x1 y1, x2 y2, x3 y3, ..., xN yN) + """ + s_min_x, s_max, s_y = arrow_points[s][:3] + p_min_x, p_max, p_y = arrow_points[p][:3] + p_max += 1 + mid_x = arrow_points[p][-1] # p_max + (s_min_x - p_max) / 3 + """print(f"s_min_x: {s_min_x}, s_max: {s_max}, s_y: 
{s_y}") + print(f"p_min_x: {p_min_x}, p_max: {p_max}, p_y: {p_y}") + print(f"mid_x: {mid_x}\n")""" + + arrow = f""" """ + if p_y != s_y: + arrow += f' ' + svg.append(arrow) + for atoms, bonds, masks, box in render: + molecule_svg = MoleculeContainer._graph_svg(atoms, bonds, masks, -font125, -box_y, width, height) + molecule_svg.insert(1, box) + svg.extend(molecule_svg) + svg.append("") + return "\n".join(svg) + + +def to_table(tree: Tree, html_path: path_type, aam: bool = False, extended=False, integration: bool = False): + """ + Write an HTML page with the synthesis paths in SVG format and corresponding reactions in SMILES format + + :param tree: # TODO + :param extended: # TODO + :param html_path: Path to save the HTML molecules_path, if None returns the html without saving it + :type html_path: str (optional) + :param aam: depict atom-to-atom mapping + :type aam: bool (optional) + :param integration: Whenever to output the full html file (False) or only the body (True) + :type integration: bool + """ + if aam: + MoleculeContainer.depict_settings(aam=True) + else: + MoleculeContainer.depict_settings(aam=False) + + paths = [] + if extended: + # Gather paths + for idx, node in tree.nodes.items(): + if node.is_solved(): + paths.append(idx) + else: + paths = tree.winning_nodes + # HTML Tags + th = '' + td = '' + font_red = "" + font_green = "" + font_head = "" + font_normal = "" + font_close = "" + + template_begin = """ + + + + + + + + Predicted Paths Report + + + + + """ + template_end = """ + + + """ + # SVG Template + # box_mark = """ + # + # + # + # """ + # table = f"<{th}>Retrosynthetic Routes" + table = """
""" + if not integration: + table += "" + else: + table += "" + + # Gather path data + table += f"{td}{font_normal}Target Molecule: {str(tree.nodes[1].curr_retron)}{font_close}" + table += (f"{td}{font_normal}Tree Size: {len(tree)}{font_close} nodes") + table += f"{td}{font_normal}Number of visited nodes: {len(tree.visited_nodes)}{font_close}" + table += f"{td}{font_normal}Found paths: {len(paths)}{font_close}" + table += f"{td}{font_normal}Time: {round(tree.curr_time, 4)}{font_close} seconds" + table += f"""\ + {td} \ + \ + \ + Target Molecule \ + \ + \ + \ + Molecule Not In Stock \ + \ + \ + \ + Molecule In Stock \ + \ + \ + """ + + for path in paths: + svg = path_graph(tree, path) # Get SVG + full_path = tree.synthesis_path(path) # Get Path + # Write SMILES of all reactions in synthesis path + step = 1 + reactions = "" + for synth_step in full_path: + reactions += f"Step {step}: {str(synth_step)}
" + step += 1 + # Concatenate all content of path + path_score = round(tree.path_score(path), 3) + table += (f'{td}{font_head}Path {path}; ' + f"Steps: {len(full_path)}; " + f"Cumulated nodes' value: {path_score}{font_close}") + # f"Cumulated nodes' value: {node._probabilities[path]}{font_close}" + table += f"{td}{svg}" + table += f"{td}{reactions}" + table += "" + + # Save or display output + if not html_path: + return table if integration else template_begin + table + template_end + + output = html_path + with open(output, "w") as html_file: + html_file.write(template_begin) + html_file.write(table) + html_file.write(template_end) diff --git a/SynTool/mcts/__init__.py b/SynTool/mcts/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..8631dd28191f90605a073dab5165c61dd1edee5e --- /dev/null +++ b/SynTool/mcts/__init__.py @@ -0,0 +1,7 @@ +from .node import * +from .tree import * +from CGRtools.containers import MoleculeContainer + +MoleculeContainer.depict_settings(aam=False) + +__all__ = ["Tree", "Node"] diff --git a/SynTool/mcts/__pycache__/__init__.cpython-310.pyc b/SynTool/mcts/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..616653fdf6615092a2676fe2da8c36a5ccb6e86c Binary files /dev/null and b/SynTool/mcts/__pycache__/__init__.cpython-310.pyc differ diff --git a/SynTool/mcts/__pycache__/evaluation.cpython-310.pyc b/SynTool/mcts/__pycache__/evaluation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4532a6a09a5f2c4413abdd5de1418d66d53701be Binary files /dev/null and b/SynTool/mcts/__pycache__/evaluation.cpython-310.pyc differ diff --git a/SynTool/mcts/__pycache__/expansion.cpython-310.pyc b/SynTool/mcts/__pycache__/expansion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..279c959c846ba1cb37e7b98998f4cae78dfe5136 Binary files /dev/null and b/SynTool/mcts/__pycache__/expansion.cpython-310.pyc differ diff 
--git a/SynTool/mcts/__pycache__/node.cpython-310.pyc b/SynTool/mcts/__pycache__/node.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aaf431499f42621f23d89931c20239159361fedf Binary files /dev/null and b/SynTool/mcts/__pycache__/node.cpython-310.pyc differ diff --git a/SynTool/mcts/__pycache__/search.cpython-310.pyc b/SynTool/mcts/__pycache__/search.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cafc6ce38d51bfdc513e5655b41967828479702 Binary files /dev/null and b/SynTool/mcts/__pycache__/search.cpython-310.pyc differ diff --git a/SynTool/mcts/__pycache__/tree.cpython-310.pyc b/SynTool/mcts/__pycache__/tree.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..106471376e5419c3a48965a104e71d9ec8d847ac Binary files /dev/null and b/SynTool/mcts/__pycache__/tree.cpython-310.pyc differ diff --git a/SynTool/mcts/evaluation.py b/SynTool/mcts/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..16919c8a04e4e993e116719f94b257470c2f67d5 --- /dev/null +++ b/SynTool/mcts/evaluation.py @@ -0,0 +1,59 @@ +""" +Module containing a class that represents a value function for prediction of synthesisablity +of new nodes in the search tree +""" + +import logging +import torch + +from pathlib import Path + +from SynTool.chem.retron import compose_retrons +from SynTool.ml.networks.value import ValueNetwork +from SynTool.ml.training import mol_to_pyg + + +class ValueFunction: + """ + Value function based on value neural network for node evaluation (synthesisability prediction) in MCTS + """ + + def __init__(self, weights_path: str) -> None: + """ + The value function predicts the probability to synthesize the target molecule with available building blocks + starting from a given retron. 
+ + :param weights_path: The value network weights location + :type weights_path: Path + """ + + value_net = ValueNetwork.load_from_checkpoint( + weights_path, + map_location=torch.device("cpu") + ) + + self.value_network = value_net.eval() + + def predict_value(self, retrons: list) -> float: + """ + The function predicts a value based on the given retrons. For prediction, retrons must be composed into a single + molecule (product) + + :param retrons: The list of retrons + :type retrons: list + """ + + molecule = compose_retrons(retrons=retrons, exclude_small=True) + pyg_graph = mol_to_pyg(molecule) + if pyg_graph: + with torch.no_grad(): + value_pred = self.value_network.forward(pyg_graph)[0].item() + else: + try: + logging.debug(f"Molecule {str(molecule)} was not preprocessed. Giving value equal to -1e6.") + except: + logging.debug(f"There is a molecule for which SMILES cannot be generated") + + value_pred = -1e6 + + return value_pred diff --git a/SynTool/mcts/expansion.py b/SynTool/mcts/expansion.py new file mode 100644 index 0000000000000000000000000000000000000000..082b5e149fc614aa0300fb28e79f763d137316b1 --- /dev/null +++ b/SynTool/mcts/expansion.py @@ -0,0 +1,83 @@ +""" +Module containing a class that represents a policy function for node expansion in the search tree +""" + +import torch +import torch_geometric +from SynTool.chem.retron import Retron +from SynTool.ml.networks.policy import PolicyNetwork +from SynTool.ml.training import mol_to_pyg +from SynTool.utils.config import PolicyNetworkConfig + + +class PolicyFunction: + """ + Policy function based on policy neural network for node expansion in MCTS + """ + + def __init__(self, policy_config: PolicyNetworkConfig, compile: bool = False): + """ + Initializes the expansion function (ranking or filter policy network). 
+ + :param policy_config: A configuration object settings for the expansion policy + :type policy_config: PolicyConfig + :param compile: XX # TODO what is compile # TODO2 compile is a bad variable name - is a builtin function name + :type compile: bool + """ + + self.config = policy_config + + policy_net = PolicyNetwork.load_from_checkpoint( + self.config.weights_path, + map_location=torch.device("cpu"), + batch_size=1, + dropout=0 + ) + + policy_net = policy_net.eval() + if compile: + self.policy_net = torch_geometric.compile(policy_net, dynamic=True) + else: + self.policy_net = policy_net + + def predict_reaction_rules(self, retron: Retron, reaction_rules: list): # TODO what is output - finish annotation + """ + The policy function predicts the list of reaction rules given a retron. + + :param retron: The current retron for which the reaction rules are predicted + :type retron: Retron + :param reaction_rules: The list of reaction rules from which applicable reaction rules are predicted and selected. 
+ :type reaction_rules: list + """ + + pyg_graph = mol_to_pyg(retron.molecule, canonicalize=False) + if pyg_graph: + with torch.no_grad(): + if self.policy_net.policy_type == "filtering": + probs, priority = self.policy_net.forward(pyg_graph) + if self.policy_net.policy_type == "ranking": + probs = self.policy_net.forward(pyg_graph) + del pyg_graph + else: + return [] + + probs = probs[0].double() + if self.policy_net.policy_type == "filtering": + priority = priority[0].double() + priority_coef = self.config.priority_rules_fraction + probs = (1 - priority_coef) * probs + priority_coef * priority + + sorted_probs, sorted_rules = torch.sort(probs, descending=True) + sorted_probs, sorted_rules = ( + sorted_probs[: self.config.top_rules], + sorted_rules[: self.config.top_rules], + ) + + if self.policy_net.policy_type == "filtering": + sorted_probs = torch.softmax(sorted_probs, -1) + + sorted_probs, sorted_rules = sorted_probs.tolist(), sorted_rules.tolist() + + for prob, rule_id in zip(sorted_probs, sorted_rules): + if prob > self.config.rule_prob_threshold: # TODO it will destroy all search if it is not correct (>0.5) + yield prob, reaction_rules[rule_id], rule_id diff --git a/SynTool/mcts/node.py b/SynTool/mcts/node.py new file mode 100755 index 0000000000000000000000000000000000000000..b6f5c3e80ed8a3252f6967a0d7d6cd9bf1261f40 --- /dev/null +++ b/SynTool/mcts/node.py @@ -0,0 +1,49 @@ +""" +Module containing a class Node that represents a node in the search tree +""" + + +class Node: + """ + Node class represents a node in the search tree + """ + + def __init__(self, retrons_to_expand: tuple = None, new_retrons: tuple = None) -> None: + """ + The function initializes the new Node object. + + :param retrons_to_expand: The tuple of retrons to be expanded. The first retron in the tuple is the current + retron which will be expanded (for which new retrons will be generated by applying the predicted reaction + rules). 
When the first retron has been successfully expanded, the second retron becomes the current retron + to be expanded. + :param new_retrons: The tuple of new retrons generated by applying the reaction rule. New retrons have already + been added to the retrons_to_expand (see Tree._expand_node). Here they are stored for information. + """ + + self.retrons_to_expand = retrons_to_expand + self.new_retrons = new_retrons + + if len(self.retrons_to_expand) == 0: + self.curr_retron = tuple() + else: + self.curr_retron = self.retrons_to_expand[0] + self.next_retrons = self.retrons_to_expand[1:] + + def __len__(self) -> int: + """ + The number of retrons in this node to expand. + """ + return len(self.retrons_to_expand) + + def __repr__(self) -> str: + """ + String representation of the node. Returns the smiles of retrons_to_expand and new_retrons. + """ + return f"retrons_to_expand: {self.retrons_to_expand}\nnew_retrons: {self.new_retrons}" + + def is_solved(self) -> bool: + """ + Is terminal node. There are not retrons for expansion. 
+ """ + + return len(self.retrons_to_expand) == 0 diff --git a/SynTool/mcts/search.py b/SynTool/mcts/search.py new file mode 100755 index 0000000000000000000000000000000000000000..2208b394ab2a762ee903aee98199d63a7c649daf --- /dev/null +++ b/SynTool/mcts/search.py @@ -0,0 +1,135 @@ +""" +Module containing functions for running tree search for the set of target molecules +""" + +import csv +import json +from pathlib import Path + +from tqdm import tqdm + +from SynTool.interfaces.visualisation import to_table, extract_routes +from SynTool.mcts.tree import Tree, TreeConfig +from SynTool.mcts.evaluation import ValueFunction +from SynTool.mcts.expansion import PolicyFunction +from SynTool.utils import path_type +from SynTool.utils.files import MoleculeReader +from SynTool.utils.config import PolicyNetworkConfig + + +def extract_tree_stats(tree, target): + """ + Collects various statistics from a tree and returns them in a dictionary format + + :param tree: The retro tree. + :param target: The target molecule or compound that you want to search for in the tree. 
It is + expected to be a string representing the SMILES notation of the target molecule + :return: A dictionary with the calculated statistics + """ + newick_tree, newick_meta = tree.newickify(visits_threshold=0) + newick_meta_line = ";".join([f"{nid},{v[0]},{v[1]},{v[2]}" for nid, v in newick_meta.items()]) + return { + "target_smiles": str(target), + "tree_size": len(tree), + "search_time": round(tree.curr_time, 1), + "found_paths": len(tree.winning_nodes), + "newick_tree": newick_tree, + "newick_meta": newick_meta_line, + } + + +def tree_search( + targets_path: path_type, + tree_config: TreeConfig, + policy_config: PolicyNetworkConfig, + reaction_rules_path: path_type, + building_blocks_path: path_type, + policy_weights_path: path_type = None, # TODO not used + value_weights_path: path_type = None, + results_root: path_type = "search_results" +): + """ + Performs a tree search on a set of target molecules using specified configuration and rules, + logging the results and statistics. + + :param tree_config: The config object containing the configuration for the tree search. + :param policy_config: The config object containing the configuration for the policy. + :param reaction_rules_path: The path to the file containing reaction rules. + :param building_blocks_path: The path to the file containing building blocks. + :param targets_path: The path to the file containing the target molecules (in SDF or SMILES format). + :param value_weights_path: The path to the file containing value weights (optional). + :param results_root: The path to the directory where the results of the tree search will be saved. Defaults to 'search_results/'. + :param retropaths_files_name: The base name for the files that will be generated to store the retro paths. Defaults to 'retropath'. #TODO arg dont exist + + This function configures and executes a tree search algorithm, leveraging reaction rules and building blocks + to find synthetic pathways for given target molecules. 
The results, including paths and statistics, are + saved in the specified directory. Logging is used to record the process and any issues encountered. + """ + + targets_file = Path(targets_path) + + # results folder + results_root = Path(results_root) + if not results_root.exists(): + results_root.mkdir() + + # output files + stats_file = results_root.joinpath("tree_search_stats.csv") + paths_file = results_root.joinpath("extracted_paths.json") + retropaths_folder = results_root.joinpath("retropaths") + retropaths_folder.mkdir(exist_ok=True) + + # stats header + stats_header = ["target_smiles", "tree_size", "search_time", + "found_paths", "newick_tree", "newick_meta"] + + # config + policy_function = PolicyFunction(policy_config=policy_config) + if tree_config.evaluation_type == 'gcn': + value_function = ValueFunction(weights_path=value_weights_path) + else: + value_function = None + + # run search + n_solved = 0 + extracted_paths = [] + with MoleculeReader(targets_file) as targets_path, open(stats_file, "w", newline="\n") as csvfile: + statswriter = csv.DictWriter(csvfile, delimiter=",", fieldnames=stats_header) + statswriter.writeheader() + + for ti, target in tqdm(enumerate(targets_path), total=len(targets_path)): + + try: + # run search + tree = Tree( + target=target, + tree_config=tree_config, + reaction_rules_path=reaction_rules_path, + building_blocks_path=building_blocks_path, + policy_function=policy_function, + value_function=value_function, + ) + _ = list(tree) + + except: + continue + + n_solved += bool(tree.winning_nodes) + + # extract routes + extracted_paths.append(extract_routes(tree)) + + # retropaths + retropaths_file = retropaths_folder.joinpath(f"retropaths_target_{ti}.html") + to_table(tree, retropaths_file, extended=True) + + # stats + statswriter.writerow(extract_tree_stats(tree, target)) + csvfile.flush() + + # + with open(paths_file, 'w') as f: + json.dump(extracted_paths, f) + + print(f"Solved number of target molecules: {n_solved}") + 
diff --git a/SynTool/mcts/tree.py b/SynTool/mcts/tree.py new file mode 100755 index 0000000000000000000000000000000000000000..80e7c5d9c746dbc0dff3ae9401501349d82dd532 --- /dev/null +++ b/SynTool/mcts/tree.py @@ -0,0 +1,659 @@ +""" +Module containing a class Tree that used for tree search of retrosynthetic paths +""" + +import logging +from collections import deque, defaultdict +from math import sqrt +from random import choice, uniform +from time import time +from typing import Dict, Set, List, Tuple + +from CGRtools.containers import MoleculeContainer +from CGRtools import smiles +from numpy.random import uniform +from tqdm.auto import tqdm +from SynTool.utils.loading import load_building_blocks, load_reaction_rules +from SynTool.chem.reaction import Reaction, apply_reaction_rule +from SynTool.chem.retron import Retron +from SynTool.mcts.evaluation import ValueFunction +from SynTool.mcts.expansion import PolicyFunction +from SynTool.mcts.node import Node +from SynTool.utils.config import TreeConfig + + +class Tree: + """ + Tree class with attributes and methods for Monte-Carlo tree search + """ + + def __init__( + self, + target: MoleculeContainer, + tree_config: TreeConfig, + reaction_rules_path: str, + building_blocks_path: str, + policy_function: PolicyFunction, + value_function: ValueFunction = None, + ): + """ + The function initializes a tree object with optional parameters for tree search for target molecule. 
+ + :param target: a target molecule for retrosynthesis paths search + :type target: MoleculeContainer + :param tree_config: a tree configuration file for retrosynthesis paths search + :type tree_config: TreeConfig + :param reaction_rules_path: a path for reaction rules file + :type reaction_rules_path: str + :param building_blocks_path: a path for building blocks file + :type building_blocks_path: str + :param policy_function: a policy function object + :type policy_function: PolicyFunction + :param value_function: a value function object + :type value_function: ValueFunction + """ + + # config parameters + self.config = tree_config + + # check target + if isinstance(target, str): + target = smiles(target) + assert (bool(target)), "Target is not defined, is not a MoleculeContainer or have no atoms" + if target: + target.canonicalize() + + target_retron = Retron(target, canonicalize=True) + target_retron.prev_retrons.append(Retron(target, canonicalize=True)) + target_node = Node(retrons_to_expand=(target_retron,), new_retrons=(target_retron,)) + + # tree structure init + self.nodes: Dict[int, Node] = {1: target_node} + self.parents: Dict[int, int] = {1: 0} + self.children: Dict[int, Set[int]] = {1: set()} + self.winning_nodes: List[int] = list() + self.visited_nodes: Set[int] = set() + self.expanded_nodes: Set[int] = set() + self.nodes_visit: Dict[int, int] = {1: 0} + self.nodes_depth: Dict[int, int] = {1: 0} + self.nodes_prob: Dict[int, float] = {1: 0.0} + self.nodes_init_value: Dict[int, float] = {1: 0.0} + self.nodes_total_value: Dict[int, float] = {1: 0.0} + + # tree building limits + self.curr_iteration: int = 0 + self.curr_tree_size: int = 2 + self.curr_time: float = 2 + + # utils + self._tqdm = None + + # policy and value functions + self.policy_function = policy_function + if self.config.evaluation_type == "gcn": + if value_function is None: + raise ValueError( + "Value function not specified while evaluation mode is 'gcn'" + ) + else: + self.value_function 
= value_function + + # building blocks and reaction reaction_rules + self.reaction_rules = load_reaction_rules(reaction_rules_path) + self.building_blocks = load_building_blocks(building_blocks_path) + + def __len__(self) -> int: + """ + Returns the current size (number of nodes) of a Tree. + """ + + return self.curr_tree_size - 1 + + def __iter__(self) -> "Tree": # TODO what is annotation "Tree" -> Tree ? + """ + The function is defining an iterator for a Tree object. Also needed for the bar progress display. + """ + + if not self._tqdm: + self._start_time = time() + self._tqdm = tqdm( + total=self.config.max_iterations, disable=self.config.silent + ) + return self + + def __repr__(self) -> str: + """ + Returns a string representation of a Tree object (target smiles, tree size, and the number of found paths). + """ + return self.report() + + def __next__(self): # TODO what is return - function annotation ? tuple (bool, [node id]) + """ + The __next__ function is used to do one iteration of the tree building. + """ + + if self.nodes[1].curr_retron.is_building_block(self.building_blocks, self.config.min_mol_size): + raise StopIteration("Target is building block \n") + + if self.curr_iteration >= self.config.max_iterations: + self._tqdm.close() + raise StopIteration("Iterations limit exceeded. \n") + elif self.curr_tree_size >= self.config.max_tree_size: + self._tqdm.close() + raise StopIteration("Max tree size exceeded or all possible paths found") + elif self.curr_time >= self.config.max_time: + self._tqdm.close() + raise StopIteration("Time limit exceeded. 
\n") + + # start new iteration + self.curr_iteration += 1 + self.curr_time = time() - self._start_time + self._tqdm.update() + + curr_depth, node_id = 0, 1 # start from the root node_id + + explore_path = True + while explore_path: + self.visited_nodes.add(node_id) + + if self.nodes_visit[node_id]: # already visited + if not self.children[node_id]: # dead node + logging.debug( + f"Tree search: bumped into node {node_id} which is dead" + ) + self._update_visits(node_id) + explore_path = False + else: + node_id = self._select_node(node_id) # select the child node + curr_depth += 1 + else: + if self.nodes[node_id].is_solved(): # found path! + self._update_visits(node_id) # this prevents expanding of bb node_id + self.winning_nodes.append(node_id) + return True, [node_id] + + elif ( + curr_depth < self.config.max_depth + ): # expand node if depth limit is not reached + self._expand_node(node_id) + if not self.children[node_id]: # node was not expanded + logging.debug(f"Tree search: node {node_id} was not expanded") + value_to_backprop = -1.0 + else: + self.expanded_nodes.add(node_id) + + if self.config.search_strategy == "evaluation_first": + # recalculate node value based on children synthesisability and backpropagation + child_values = [ + self.nodes_init_value[child_id] + for child_id in self.children[node_id] + ] + + if self.config.evaluation_agg == "max": + value_to_backprop = max(child_values) + + elif self.config.evaluation_agg == "average": + value_to_backprop = sum(child_values) / len( + self.children[node_id] + ) + + else: + raise ValueError( + f"Invalid evaluation aggregation mode: {self.config.evaluation_agg} " + f"Allowed values are 'max', 'average'" + ) + elif self.config.search_strategy == "expansion_first": + value_to_backprop = self._get_node_value(node_id) + + else: + raise ValueError( + f"Invalid search_strategy: {self.config.search_strategy}: " + f"Allowed values are 'expansion_first', 'evaluation_first'" + ) + + # backpropagation + 
self._backpropagate(node_id, value_to_backprop) + self._update_visits(node_id) + explore_path = False + + if self.children[node_id]: + # found after expansion + found_after_expansion = set() + for child_id in iter(self.children[node_id]): + if self.nodes[child_id].is_solved(): + found_after_expansion.add(child_id) + self.winning_nodes.append(child_id) + + if found_after_expansion: + return True, list(found_after_expansion) + + else: + self._backpropagate(node_id, self.nodes_total_value[node_id]) + self._update_visits(node_id) + explore_path = False + + return False, [node_id] + + def _ucb(self, node_id: int) -> float: + """ + The function calculates the Upper Confidence Bound (UCB) for a given node. + + :param node_id: The `node_id` parameter is an integer that represents the ID of a node in a tree + :type node_id: int + """ + + prob = self.nodes_prob[node_id] # Predicted by policy network score + visit = self.nodes_visit[node_id] + + if self.config.ucb_type == "puct": + u = ( + self.config.c_ucb * prob * sqrt(self.nodes_visit[self.parents[node_id]]) + ) / (visit + 1) + return self.nodes_total_value[node_id] + u + elif self.config.ucb_type == "uct": + u = ( + self.config.c_ucb + * sqrt(self.nodes_visit[self.parents[node_id]]) + / (visit + 1) + ) + return self.nodes_total_value[node_id] + u + elif self.config.ucb_type == "value": + return self.nodes_init_value[node_id] / (visit + 1) + else: + raise ValueError(f"I don't know this UCB type: {self.config.ucb_type}") + + def _select_node(self, node_id: int) -> int: + """ + This function selects a node based on its UCB value and returns the ID of the node with the highest value of + the UCB function. 
+ + :param node_id: The `node_id` parameter is an integer that represents the ID of a node + :type node_id: int + """ + + if self.config.epsilon > 0: + n = uniform(0, 1) + if n < self.config.epsilon: + return choice(list(self.children[node_id])) + + best_score, best_children = None, [] + for child_id in self.children[node_id]: + score = self._ucb(child_id) + if best_score is None or score > best_score: + best_score, best_children = score, [child_id] + elif score == best_score: + best_children.append(child_id) + return choice(best_children) + + def _expand_node(self, node_id: int) -> None: + """ + The function expands a given node by generating new retrons with policy (expansion) policy. + + :param node_id: The `node_id` parameter is an integer that represents the ID of the current node + :type node_id: int + """ + curr_node = self.nodes[node_id] + prev_retrons = curr_node.curr_retron.prev_retrons + + tmp_retrons = set() + for prob, rule, rule_id in self.policy_function.predict_reaction_rules( + curr_node.curr_retron, self.reaction_rules + ): + for products in apply_reaction_rule(curr_node.curr_retron.molecule, rule): + # check repeated products + if not products or not set(products) - tmp_retrons: + continue + tmp_retrons.update(products) + + for molecule in products: + molecule.meta["reactor_id"] = rule_id + + new_retrons = tuple(Retron(mol) for mol in products) + scaled_prob = prob * len( + list(filter(lambda x: len(x) > self.config.min_mol_size, products)) + ) + + if set(prev_retrons).isdisjoint(new_retrons): + retrons_to_expand = ( + *curr_node.next_retrons, + *( + x + for x in new_retrons + if not x.is_building_block( + self.building_blocks, self.config.min_mol_size + ) + ), + ) + + child_node = Node( + retrons_to_expand=retrons_to_expand, new_retrons=new_retrons + ) + + for new_retron in new_retrons: + new_retron.prev_retrons = [new_retron, *prev_retrons] + + self._add_node(node_id, child_node, scaled_prob) + + def _add_node(self, node_id: int, new_node: 
Node, policy_prob: float = None) -> None: + """ + This function adds a new node to a tree with its predicted policy probability. + + :param node_id: ID of the parent node + :type node_id: int + :param new_node: The `new_node` is an instance of the`Node` class + :type new_node: Node + :param policy_prob: The `policy_prob` a float value that represents the probability associated with a new node. + :type policy_prob: float + """ + + new_node_id = self.curr_tree_size + + self.nodes[new_node_id] = new_node + self.parents[new_node_id] = node_id + self.children[node_id].add(new_node_id) + self.children[new_node_id] = set() + self.nodes_visit[new_node_id] = 0 + self.nodes_prob[new_node_id] = policy_prob + self.nodes_depth[new_node_id] = self.nodes_depth[node_id] + 1 + self.curr_tree_size += 1 + + if self.config.search_strategy == "evaluation_first": + node_value = self._get_node_value(new_node_id) + elif self.config.search_strategy == "expansion_first": + node_value = self.config.init_node_value + else: + raise ValueError( + f"Invalid search_strategy: {self.config.search_strategy}: " + f"Allowed values are 'expansion_first', 'evaluation_first'" + ) + + self.nodes_init_value[new_node_id] = node_value + self.nodes_total_value[new_node_id] = node_value + + def _get_node_value(self, node_id: int) -> float: + """ + This function calculates the value for the given node. 
+ + :param node_id: ID of the given node + :type node_id: int + """ + + node = self.nodes[node_id] + + if self.config.evaluation_type == "random": + node_value = uniform() + + elif self.config.evaluation_type == "rollout": + node_value = min( + ( + self._rollout_node(retron, current_depth=self.nodes_depth[node_id]) + for retron in node.retrons_to_expand + ), + default=1.0, + ) + + elif self.config.evaluation_type == "gcn": + node_value = self.value_function.predict_value(node.new_retrons) + + else: + raise ValueError( + f"I don't know this evaluation mode: {self.config.evaluation_type}" + ) + + return node_value + + def _update_visits(self, node_id: int) -> None: + """ + The function updates the number of visits from a given node to a root node. + + :param node_id: The ID of a current node + :type node_id: int + """ + + while node_id: + self.nodes_visit[node_id] += 1 + node_id = self.parents[node_id] + + def _backpropagate(self, node_id: int, value: float) -> None: + """ + The function backpropagates a value through a tree of a given node specified by node_id. + + :param node_id: The ID of a given node from which to backpropagate value + :type node_id: int + :param value: The value to backpropagate + :type value: float + """ + while node_id: + if self.config.backprop_type == "muzero": + self.nodes_total_value[node_id] = ( + self.nodes_total_value[node_id] * self.nodes_visit[node_id] + value + ) / (self.nodes_visit[node_id] + 1) + elif self.config.backprop_type == "cumulative": + self.nodes_total_value[node_id] += value + else: + raise ValueError( + f"I don't know this backpropagation type: {self.config.backprop_type}" + ) + node_id = self.parents[node_id] + + def _rollout_node(self, retron: Retron, current_depth: int = None) -> float: + """ + The function `_rollout_node` performs a rollout simulation from a given node in a tree. + Given the current retron, find the first successful reaction and return the new retrons. 
+ + If the retron is a building_block, return 1.0, else check the first successful reaction; + + If the reaction is not successful, return -1.0; + + If the reaction is successful, but the generated retrons are not the building_blocks and the retrons + cannot be generated without exceeding current_depth threshold, return -0.5; + + If the reaction is successful, but the retrons are not the building_blocks and the retrons + cannot be generated, return -1.0; + + :param retron: A Retron object + :type retron: Retron + :param current_depth: The current depth of the tree + :type current_depth: int + """ + + max_depth = self.config.max_depth - current_depth + + # retron checking + if retron.is_building_block(self.building_blocks, self.config.min_mol_size): + return 1.0 + + if max_depth == 0: + logging.debug("Rollout: tried to perform rollout on the leaf node") + return -0.5 + + # retron simulating + occurred_retrons = set() + retrons_to_expand = deque([retron]) + history = defaultdict(dict) + rollout_depth = 0 + while retrons_to_expand: + # Iterate through reactors and pick first successful reaction. 
+ # Check products of the reaction if you can find them in in-building_blocks data + # If not, then add missed products to retrons_to_expand and try to decompose them + if len(history) >= max_depth: + logging.debug( + f"Rollout: max depth of rollout is reached with these " + f"retrons to expand: {retrons_to_expand} {history}", + ) + reward = -0.5 + return reward + + current_retron = retrons_to_expand.popleft() + history[rollout_depth]["target"] = current_retron + occurred_retrons.add(current_retron) + + # Pick the first successful reaction while iterating through reactors + reaction_rule_applied = False + for prob, rule, rule_id in self.policy_function.predict_reaction_rules( + current_retron, self.reaction_rules + ): + for products in apply_reaction_rule(current_retron.molecule, rule): + if products: + reaction_rule_applied = True + break + + if reaction_rule_applied: + history[rollout_depth]["rule_index"] = rule_id + break + + if not reaction_rule_applied: + logging.debug( + f"Rollout: no reaction rule was applied for the " + f"molecule {current_retron} on rollout depth {rollout_depth}" + ) + reward = -1.0 + return reward + + products = tuple(Retron(product) for product in products) # TODO /!\ Is it ok how products is defined above (line 496) ? 
Seems to + # TODO /!\ consider only last iterable of apply_reaction_rule + history[rollout_depth]["products"] = products + + # check loops + if any(x in occurred_retrons for x in products) and products: + # Sometimes manual can create a loop, when + logging.debug("Rollout: rollout got in the loop: %s", history) + # print('occurred_retrons') + reward = -1.0 + return reward + + if occurred_retrons.isdisjoint(products): + # Added number of atoms check + retrons_to_expand.extend( + [ + x + for x in products + if not x.is_building_block( + self.building_blocks, self.config.min_mol_size + ) + ] + ) + rollout_depth += 1 + + reward = 1.0 + return reward + + def report(self) -> str: + """ + Returns the string representation of the tree. + """ + + return ( + f"Tree for: {str(self.nodes[1].retrons_to_expand[0])}\n" + f"Number of nodes: {len(self)}\nNumber of visited nodes: {len(self.visited_nodes)}\n" + f"Found paths: {len(self.winning_nodes)}\nTime: {round(self.curr_time, 1)} seconds" + ) + + def path_score(self, node_id: int) -> float: + """ + The function calculates the score of a given path from the node with node_id to the root node. + + :param node_id: The ID of a given node + :type node_id: int + """ + + cumulated_nodes_value, path_length = 0, 0 + while node_id: + path_length += 1 + + cumulated_nodes_value += self.nodes_total_value[node_id] + node_id = self.parents[node_id] + + return cumulated_nodes_value / (path_length ** 2) + + def path_to_node(self, node_id: int) -> list: + """ + The function returns the path (list of IDs of nodes) to from a node specified by node_id to the root node. 
+ + :param node_id: The ID of a given node + :type node_id: int + """ + + nodes = [] + while node_id: + nodes.append(node_id) + node_id = self.parents[node_id] + return [self.nodes[node_id] for node_id in reversed(nodes)] + + def synthesis_path(self, node_id: int) -> Tuple[Reaction, ...]: + """ + Given a node_id, return a tuple of Reactions that represent the synthesis path from the + node specified with node_id to the root node + + :param node_id: The ID of a given node + :type node_id: int + """ + + nodes = self.path_to_node(node_id) + + tmp = [ + Reaction( + [x.molecule for x in after.new_retrons], + [before.curr_retron.molecule], + ) + for before, after in zip(nodes, nodes[1:]) + ] + + for r in tmp: + r.clean2d() + return tuple(reversed(tmp)) + + def newickify(self, visits_threshold: int = 0, root_node_id: int = 1): # TODO what is return here ? + """ + Adopted from https://stackoverflow.com/questions/50003007/how-to-convert-python-dictionary-to-newick-form-format + :param visits_threshold: the minimum number of visits for the given node # TODO is this explanation correct ? + :type visits_threshold: int + :param root_node_id: The ID of a root node + :type root_node_id: int + """ + visited_nodes = set() + + def newick_render_node(current_node_id: int) -> str: + """ + Recursively generates a Newick string representation of a tree + + :param current_node_id: The identifier of the current node in the tree + :type current_node_id: The identifier of the current node in the tree + :return: A string representation of a node in a Newick format + """ + assert ( + current_node_id not in visited_nodes + ), "Error: The tree may not be circular!" 
+ node_visit = self.nodes_visit[current_node_id] + + visited_nodes.add(current_node_id) + if self.children[current_node_id]: + # Nodes + children = [ + child + for child in list(self.children[current_node_id]) + if self.nodes_visit[child] >= visits_threshold + ] + children_strings = [newick_render_node(child) for child in children] + children_strings = ",".join(children_strings) + if children_strings: + return f"({children_strings}){current_node_id}:{node_visit}" + else: + # Leafs within threshold + return f"{current_node_id}:{node_visit}" + else: + # Leafs + return f"{current_node_id}:{node_visit}" + + newick_string = newick_render_node(root_node_id) + ";" + + meta = {} + for node_id in iter(visited_nodes): + node_value = round(self.nodes_total_value[node_id], 3) + + node_synthesisability = round(self.nodes_init_value[node_id]) + + visit_in_node = self.nodes_visit[node_id] + meta[node_id] = (node_value, node_synthesisability, visit_in_node) + + return newick_string, meta diff --git a/SynTool/ml/__init__.py b/SynTool/ml/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SynTool/ml/__pycache__/__init__.cpython-310.pyc b/SynTool/ml/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86b6a3c2e377e33082d8f1eb0d402cc860cc83c1 Binary files /dev/null and b/SynTool/ml/__pycache__/__init__.cpython-310.pyc differ diff --git a/SynTool/ml/networks/__init__.py b/SynTool/ml/networks/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SynTool/ml/networks/__pycache__/__init__.cpython-310.pyc b/SynTool/ml/networks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fbc0740fc7f505f1f3c8cc818e5a755d7b0b059 Binary files /dev/null and b/SynTool/ml/networks/__pycache__/__init__.cpython-310.pyc differ diff --git 
a/SynTool/ml/networks/__pycache__/modules.cpython-310.pyc b/SynTool/ml/networks/__pycache__/modules.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5033f0723ba06ccb8216c61cabb7cd0eff56cbd5
Binary files /dev/null and b/SynTool/ml/networks/__pycache__/modules.cpython-310.pyc differ
diff --git a/SynTool/ml/networks/__pycache__/policy.cpython-310.pyc b/SynTool/ml/networks/__pycache__/policy.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e94c9bc256a5e0de68c1afb5d52668c014f54cd3
Binary files /dev/null and b/SynTool/ml/networks/__pycache__/policy.cpython-310.pyc differ
diff --git a/SynTool/ml/networks/__pycache__/value.cpython-310.pyc b/SynTool/ml/networks/__pycache__/value.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..88b7ef3e5031bf963561e71913962a5fcbb4c7df
Binary files /dev/null and b/SynTool/ml/networks/__pycache__/value.cpython-310.pyc differ
diff --git a/SynTool/ml/networks/modules.py b/SynTool/ml/networks/modules.py
new file mode 100755
index 0000000000000000000000000000000000000000..daee812b750c6bfbbe4d20b31cdc61c9be548fde
--- /dev/null
+++ b/SynTool/ml/networks/modules.py
@@ -0,0 +1,188 @@
"""
Module containing classes pytorch architectures of policy and value neural networks
"""

from abc import ABC, abstractmethod

import torch
from adabelief_pytorch import AdaBelief
from pytorch_lightning import LightningModule
from torch.nn import Linear, Module, Dropout, ModuleList, GELU, LayerNorm, ModuleDict
from torch.nn.functional import relu
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.nn.conv import GCNConv
from torch_geometric.nn.pool import global_add_pool


class GraphEmbedding(Module):
    """
    Needed to convert molecule atom vectors to the single vector using graph convolution
    """

    def __init__(self, vector_dim: int = 512, dropout: float = 0.4, num_conv_layers: int = 5):
        """
        It initializes a graph convolutional module. Needed to convert molecule atom vectors to the single vector
        using graph convolution.

        :param vector_dim: The dimensionality of the hidden layers and output layer of graph convolution module.
        :type vector_dim: int
        :param dropout: Dropout is a regularization technique used in neural networks to prevent overfitting.
        It randomly sets a fraction of input units to 0 at each update during training time.
        :type dropout: float
        :param num_conv_layers: The number of convolutional layers in a graph convolutional module.
        :type num_conv_layers: int
        """

        super(GraphEmbedding, self).__init__()
        # 11 input features per atom — presumably the fixed atom-descriptor size
        # produced by mol_to_pyg; TODO confirm against the featurizer.
        self.expansion = Linear(11, vector_dim)
        self.dropout = Dropout(dropout)
        self.gcn_convs = ModuleList([GCNConv(vector_dim, vector_dim, improved=True, ) for _ in range(num_conv_layers)])

    def forward(self, graph, batch_size):
        """
        The forward function takes a graph as input and performs graph convolution on it.

        :param batch_size: number of graphs in the batch (passed to global_add_pool)
        :param graph: The molecular graph, where each atom is represented by the atom/bond vector
        """
        atoms, connections = graph.x.float(), graph.edge_index.long()
        # log1p-style compression of raw count features before the linear expansion
        atoms = torch.log(atoms + 1)
        atoms = self.expansion(atoms)
        for gcn_conv in self.gcn_convs:
            # residual connection around each dropout(relu(GCN)) layer
            atoms = atoms + self.dropout(relu(gcn_conv(atoms, connections)))

        # sum-pool atom vectors into one vector per molecule
        return global_add_pool(atoms, graph.batch, size=batch_size)


class GraphEmbeddingConcat(GraphEmbedding, Module):
    # Variant of GraphEmbedding that concatenates the outputs of all GCN layers
    # (each of size vector_dim // num_conv_layers) instead of summing residuals.
    def __init__(self, vector_dim: int = 512, dropout: float = 0.4, num_conv_layers: int = 8):
        # NOTE(review): super().__init__() runs GraphEmbedding.__init__ with its
        # DEFAULT arguments, building full-size layers that are immediately
        # overwritten below — wasted allocation (and RNG draws); consider passing
        # the actual arguments or calling Module.__init__ directly.
        super(GraphEmbeddingConcat, self).__init__()

        gcn_dim = vector_dim // num_conv_layers

        self.expansion = Linear(11, gcn_dim)
        self.dropout = Dropout(dropout)
        self.gcn_convs = ModuleList(
            [
                ModuleDict(
                    {
                        "gcn": GCNConv(gcn_dim, gcn_dim, improved=True),
                        "activation": GELU(),
                        # "norm": LayerNorm(gcn_dim)
                    }
                )
                for _ in range(num_conv_layers)
            ]
        )

    def forward(self, graph, batch_size):
        # Same input handling as GraphEmbedding.forward, but layer outputs are
        # collected and concatenated along the feature dimension.
        atoms, connections = graph.x.float(), graph.edge_index.long()
        atoms = torch.log(atoms + 1)
        atoms = self.expansion(atoms)

        collected_atoms = []
        for gcn_convs in self.gcn_convs:
            atoms = gcn_convs["gcn"](atoms, connections)
            atoms = gcn_convs["activation"](atoms)
            # atoms = gcn_convs["norm"](atoms)
            atoms = self.dropout(atoms)
            collected_atoms.append(atoms)

        # final per-atom vector: concat of all layer outputs
        # (size num_conv_layers * gcn_dim, i.e. <= vector_dim due to floor division)
        atoms = torch.cat(collected_atoms, dim=-1)

        return global_add_pool(atoms, graph.batch, size=batch_size)


class MCTSNetwork(LightningModule, ABC):
    """
    Basic class for policy and value networks
    """

    def __init__(self, vector_dim, batch_size, dropout=0.4, num_conv_layers=5, learning_rate=0.001, gcn_concat=False):
        """
        The basic class for MCTS graph convolutional neural networks (policy and value network).

        :param vector_dim: The dimensionality of the hidden layers and output layer of graph convolution module.
        :type vector_dim: int
        :param dropout: Dropout is a regularization technique used in neural networks to prevent overfitting.
        It randomly sets a fraction of input units to 0 at each update during training time.
        :type dropout: float
        :param num_conv_layers: The number of convolutional layers in a graph convolutional module.
        :type num_conv_layers: int
        :param learning_rate: The learning rate determines how quickly the model learns from the training data.
        :type learning_rate: float
        """
        super(MCTSNetwork, self).__init__()
        # choose between the residual-sum and the concatenating embedder
        if gcn_concat:
            self.embedder = GraphEmbeddingConcat(vector_dim, dropout, num_conv_layers)
        else:
            self.embedder = GraphEmbedding(vector_dim, dropout, num_conv_layers)
        self.batch_size = batch_size
        self.lr = learning_rate

    @abstractmethod
    def forward(self, batch):
        """
        The forward function takes a batch of input data and performs forward propagation through the neural network.

        :param batch: The batch parameter is a collection of input data that is processed together in a single forward
        pass through the neural network.
        """
        ...

    @abstractmethod
    def _get_loss(self, batch):
        """
        This function is used to calculate the loss for a given batch of data.

        :param batch: The batch parameter is a batch of input data that is used to compute the loss.
        """
        ...

    def training_step(self, batch, batch_idx):
        """
        Calculates the loss for a given training batch and logs the loss value.

        :param batch: The batch of data that is used for training.
        :param batch_idx: The index of the batch.
        :return: the value of the training loss.
        """
        metrics = self._get_loss(batch)
        # log every metric returned by the subclass, prefixed with the split name
        for name, value in metrics.items():
            self.log('train_' + name, value, prog_bar=True, on_step=True, on_epoch=True, batch_size=self.batch_size)
        return metrics['loss']

    def validation_step(self, batch, batch_idx):
        """
        Calculates the loss for a given validation batch and logs the loss value.

        :param batch: The batch of data that is used for validation.
        :param batch_idx: The index of the batch.
        """
        metrics = self._get_loss(batch)
        for name, value in metrics.items():
            self.log('val_' + name, value, on_epoch=True, batch_size=self.batch_size)

    def test_step(self, batch, batch_idx):
        """
        Calculates the loss for a given test batch and logs the loss value.

        :param batch: The batch of data that is used for testing.
        :param batch_idx: The index of the batch.
        """
        metrics = self._get_loss(batch)
        for name, value in metrics.items():
            self.log('test_' + name, value, on_epoch=True, batch_size=self.batch_size)

    def configure_optimizers(self):
        """
        Returns an optimizer and a learning rate scheduler for training a model using the AdaBelief optimizer
        and ReduceLROnPlateau scheduler.
        :return: The optimizer and a scheduler.
        """
        optimizer = AdaBelief(self.parameters(), lr=self.lr, eps=1e-16, betas=(0.9, 0.999), weight_decouple=True,
                              rectify=True, weight_decay=0.01, print_change_log=False)

        # the scheduler monitors 'val_loss', which validation_step logs as 'val_' + 'loss'
        lr_scheduler = ReduceLROnPlateau(optimizer, patience=3, factor=0.8, min_lr=5e-5, verbose=True)
        scheduler = {'scheduler': lr_scheduler, 'reduce_on_plateau': True, 'monitor': 'val_loss'}
        return [optimizer], [scheduler]
diff --git a/SynTool/ml/networks/policy.py b/SynTool/ml/networks/policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..f46cb96509f6d501f332b157dfed7200b026a754
--- /dev/null
+++ b/SynTool/ml/networks/policy.py
@@ -0,0 +1,110 @@
from abc import ABC
from dataclasses import dataclass
from typing import Dict, Any

import yaml
import torch
from pytorch_lightning import LightningModule
from torch.nn import Linear
from torch.nn.functional import binary_cross_entropy_with_logits, cross_entropy, one_hot
from torchmetrics.functional.classification import recall, specificity, f1_score

from SynTool.ml.networks.modules import MCTSNetwork


class PolicyNetwork(MCTSNetwork, LightningModule, ABC):
    """
    Policy value network
    """

    def __init__(self, n_rules, vector_dim, policy_type="filtering", *args, **kwargs):
        """
        Initializes a policy network with the given number of reaction rules (output dimension) and vector graph
        embedding dimension, and creates linear layers for predicting the regular and priority reaction rules.

        :param n_rules: The number of reaction rules in the policy network.
        :param vector_dim: The dimensionality of the input vectors.
+ """ + super(PolicyNetwork, self).__init__(vector_dim, *args, **kwargs) + self.save_hyperparameters() + self.policy_type = policy_type + self.n_rules = n_rules + self.y_predictor = Linear(vector_dim, n_rules) + if self.policy_type == "filtering": + self.priority_predictor = Linear(vector_dim, n_rules) + + def forward(self, batch): + """ + The forward function takes a molecular graph, applies a graph convolution and sigmoid layers to predict + regular and priority reaction rules. + + :param batch: The input batch of molecular graphs. + :return: Returns the vector of probabilities (given by sigmoid) of successful application of regular and + priority reaction rules. + """ + x = self.embedder(batch, self.batch_size) + y = self.y_predictor(x) + if self.policy_type == "filtering": + y = torch.sigmoid(y) + priority = torch.sigmoid(self.priority_predictor(x)) + return y, priority + elif self.policy_type == "ranking": + y = torch.softmax(y, dim=-1) + return y + + def _get_loss(self, batch): + """ + Calculates the loss and various classification metrics for a given batch for reaction rules prediction. + + :param batch: The batch of molecular graphs. + :return: a dictionary with loss value and balanced accuracy of reaction rules prediction. 
+ """ + true_y = batch.y_rules.long() + x = self.embedder(batch, self.batch_size) + pred_y = self.y_predictor(x) + + if self.policy_type == "ranking": + true_one_hot = one_hot(true_y, num_classes=self.n_rules) + loss = cross_entropy(pred_y, true_one_hot.float()) + ba_y = ( + recall(pred_y, true_y, task="multiclass", num_classes=self.n_rules) + + specificity(pred_y, true_y, task="multiclass", num_classes=self.n_rules) + ) / 2 + f1_y = f1_score(pred_y, true_y, task="multiclass", num_classes=self.n_rules) + metrics = { + 'loss': loss, + 'balanced_accuracy_y': ba_y, + 'f1_score_y': f1_y + } + elif self.policy_type == "filtering": + loss_y = binary_cross_entropy_with_logits(pred_y, true_y.float()) + ba_y = ( + recall(pred_y, true_y, task="multilabel", num_labels=self.n_rules) + + specificity(pred_y, true_y, task="multilabel", num_labels=self.n_rules) + ) / 2 + f1_y = f1_score(pred_y, true_y, task="multilabel", num_labels=self.n_rules) + + true_priority = batch.y_priority.float() + pred_priority = self.priority_predictor(x) + + loss_priority = binary_cross_entropy_with_logits(pred_priority, true_priority) + loss = loss_y + loss_priority + + true_priority = true_priority.long() + + ba_priority = ( + recall(pred_priority, true_priority, task="multilabel", num_labels=self.n_rules) + + specificity(pred_priority, true_priority, task="multilabel", num_labels=self.n_rules) + ) / 2 + f1_priority = f1_score(pred_priority, true_priority, task="multilabel", num_labels=self.n_rules) + metrics = { + 'loss': loss, + 'balanced_accuracy_y': ba_y, + 'f1_score_y': f1_y, + 'balanced_accuracy_priority': ba_priority, + 'f1_score_priority': f1_priority + } + else: + raise ValueError(f"Invalid mode: {self.policy_type}") + + return metrics diff --git a/SynTool/ml/networks/value.py b/SynTool/ml/networks/value.py new file mode 100644 index 0000000000000000000000000000000000000000..a8419430de5c5835243e493e73466cb087e9869b --- /dev/null +++ b/SynTool/ml/networks/value.py @@ -0,0 +1,57 @@ +from abc 
import ABC  # (continuation of the "from abc" import started on the previous line)

import torch
from pytorch_lightning import LightningModule
from torch.nn import Linear
from torch.nn.functional import binary_cross_entropy_with_logits
from torchmetrics.functional.classification import binary_recall, binary_specificity, binary_f1_score

from SynTool.ml.networks.modules import MCTSNetwork


class ValueNetwork(MCTSNetwork, LightningModule, ABC):
    """
    Value value network
    """

    def __init__(self, vector_dim, *args, **kwargs):
        """
        Initializes a value network, and creates linear layer for predicting the synthesisability of given retron
        represented by molecular graph.

        :param vector_dim: The dimensionality of the output linear layer.
        """
        super(ValueNetwork, self).__init__(vector_dim, *args, **kwargs)
        self.save_hyperparameters()
        # single-logit head: synthesisability score of the pooled graph embedding
        self.predictor = Linear(vector_dim, 1)

    def forward(self, batch) -> torch.Tensor:
        """
        The forward function takes a batch of molecular graphs, applies a graph convolution returns the synthesisability
        (probability given by sigmoid function) of a given retron represented by molecular graph precessed by
        graph convolution.

        :param batch: The batch of molecular graphs.
        :return: a predicted synthesisability (between 0 and 1).
        """
        x = self.embedder(batch, self.batch_size)
        x = torch.sigmoid(self.predictor(x))
        return x

    def _get_loss(self, batch):
        """
        Calculates the loss and various classification metrics for a given batch for retron synthesysability prediction.

        :param batch: The batch of molecular graphs.
        :return: a dictionary with loss value and balanced accuracy of retron synthesysability prediction.
        """
        true_y = batch.y.float()
        true_y = torch.unsqueeze(true_y, -1)  # match the (batch, 1) predictor output
        x = self.embedder(batch, self.batch_size)
        # raw logits — binary_cross_entropy_with_logits applies sigmoid internally
        pred_y = self.predictor(x)
        loss = binary_cross_entropy_with_logits(pred_y, true_y)
        true_y = true_y.long()
        # balanced accuracy = mean of recall (sensitivity) and specificity
        ba = (binary_recall(pred_y, true_y) + binary_specificity(pred_y, true_y)) / 2
        f1 = binary_f1_score(pred_y, true_y)
        metrics = {'loss': loss, 'balanced_accuracy': ba, 'f1_score': f1}
        return metrics
diff --git a/SynTool/ml/training/__init__.py b/SynTool/ml/training/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..483f260f2058b11dc071d7717451ed05e9a90c0b
--- /dev/null
+++ b/SynTool/ml/training/__init__.py
@@ -0,0 +1,11 @@
from .supervised import *
from .preprocessing import ValueNetworkDataset, mol_to_pyg, MENDEL_INFO
from .supervised import create_policy_dataset, run_policy_training

__all__ = [
    "ValueNetworkDataset",
    "mol_to_pyg",
    "MENDEL_INFO",
    'create_policy_dataset',
    'run_policy_training'
]
diff --git a/SynTool/ml/training/__pycache__/__init__.cpython-310.pyc b/SynTool/ml/training/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..88e857732184d926c46e4a45155b6031e28bac00
Binary files /dev/null and b/SynTool/ml/training/__pycache__/__init__.cpython-310.pyc differ
diff --git a/SynTool/ml/training/__pycache__/preprocessing.cpython-310.pyc b/SynTool/ml/training/__pycache__/preprocessing.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..be4467af4767be494f3127ad70e1180084551148
Binary files /dev/null and b/SynTool/ml/training/__pycache__/preprocessing.cpython-310.pyc differ
diff --git a/SynTool/ml/training/__pycache__/reinforcement.cpython-310.pyc b/SynTool/ml/training/__pycache__/reinforcement.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c9b3244c15794a83730756a72d188db41b4c741
Binary files /dev/null and
b/SynTool/ml/training/__pycache__/reinforcement.cpython-310.pyc differ diff --git a/SynTool/ml/training/__pycache__/supervised.cpython-310.pyc b/SynTool/ml/training/__pycache__/supervised.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33eb092bca3a768b36e533eac88f56b6598fe96b Binary files /dev/null and b/SynTool/ml/training/__pycache__/supervised.cpython-310.pyc differ diff --git a/SynTool/ml/training/preprocessing.py b/SynTool/ml/training/preprocessing.py new file mode 100755 index 0000000000000000000000000000000000000000..1b5bcd66941c7b0dd2da3bcb83911cd41d348bb7 --- /dev/null +++ b/SynTool/ml/training/preprocessing.py @@ -0,0 +1,513 @@ +""" +Module containing functions for preparation of the training sets for policy and value network +""" + +import os +import pickle +from abc import ABC +from multiprocessing import Manager, Pool +from pathlib import Path +from typing import List + +import ray +import torch +from CGRtools import smiles +from CGRtools.containers import MoleculeContainer +from CGRtools.exceptions import InvalidAromaticRing +from CGRtools.reactor import Reactor +from ray.util.queue import Queue, Empty +from torch_geometric.data import Data, InMemoryDataset +from torch_geometric.data.makedirs import makedirs +from torch_geometric.transforms import ToUndirected +from tqdm import tqdm + +from SynTool.utils.loading import load_reaction_rules +from SynTool.chem.utils import unite_molecules +from SynTool.utils.files import ReactionReader + + +class ValueNetworkDataset(InMemoryDataset, ABC): + """ + Value network dataset + """ + + def __init__(self, extracted_retrons): + """ + Initializes a value network dataset object. + + :param extracted_retrons: The path to a file containing processed molecules (retrons) extracted from + search tree. 
+ """ + super().__init__(None, None, None) + + if extracted_retrons: + self.data, self.slices = self.prepare_from_extracted_retrons(extracted_retrons) + + def prepare_pyg(self, molecule, label): + """ + It takes a molecule as input, and converts the molecule to a PyTorch geometric graph, + assigns the reward value (label) to the graph, and returns the graph. + + :param molecule: The molecule object that represents a chemical compound. + :return: a PyTorch Geometric (PyG) graph representation of a molecule. If the molecule has a "label" key in its + metadata, the function sets the reward variable to the value of the "label" key converted to a float. Otherwise, + the reward variable is set to 0. + """ + if len(molecule) > 2: + pyg = mol_to_pyg(molecule) + if pyg: + pyg.y = torch.tensor([label]) + return pyg + else: + return None + + def prepare_from_extracted_retrons(self, extracted_retrons): + """ + The function prepares processed data from a given file path by reading SMILES data, converting it to + PyTorch geometric graph format, and returning the processed data and slices. + + :param extracted_retrons: The path to a file containing processed molecules. It is assumed that the file + is in a format that can be read by the SMILESRead class, and that it has a header row with a column + named "label" + :return: data (PyTorch geometric graphs) and slices. + """ + processed_data = [] + for smi, label in extracted_retrons.items(): + mol = smiles(smi) + pyg = self.prepare_pyg(mol, label) + if pyg: + processed_data.append(pyg) + data, slices = self.collate(processed_data) + return data, slices + + +class RankingPolicyDataset(InMemoryDataset): + """ + Policy network dataset + """ + + def __init__(self, reactions_path, reaction_rules_path, output_path): + """ + Initializes a policy network dataset object. + + :param reactions_path: The path to the file containing the reaction data used for extraction of reaction rules. 
+ :param reaction_rules_path: The path to the file containing the reaction rules. + :param output_path: The output path is the location where policy network dataset will be stored. + """ + super().__init__(None, None, None) + + self.reactions_path = Path(reactions_path).resolve(strict=True) + self.reaction_rules_path = Path(reaction_rules_path).resolve(strict=True) + self.output_path = output_path + + if output_path and os.path.exists(output_path): + self.data, self.slices = torch.load(self.output_path) + else: + self.data, self.slices = self.prepare_data() + + @property + def num_classes(self) -> int: + return self._infer_num_classes(self._data.y_rules) + + def prepare_data(self): + """ + The function prepares data by loading reaction rules, initializing Ray, preprocessing the molecules, collating + the data, and returning the data and slices. + :return: data (PyTorch geometric graphs) and slices. + """ + + with open(self.reaction_rules_path, "rb") as inp: + reaction_rules = pickle.load(inp) + + dataset = {} + for rule_i, (_, reactions_ids) in enumerate(reaction_rules): + for reaction_id in reactions_ids: + dataset[reaction_id] = rule_i + dataset = dict(sorted(dataset.items())) + + list_of_graphs = [] + with ReactionReader(self.reactions_path) as reactions: + + for reaction_id, reaction in tqdm(enumerate(reactions)): + + rule_id = dataset.get(reaction_id) + if rule_id: + try: # TODO force solution <= MENDEL INFO doesnt have cadmium prop (Cd) + molecule = unite_molecules(reaction.products) + pyg_graph = mol_to_pyg(molecule) + except KeyError: + continue + + if pyg_graph is not None: + pyg_graph.y_rules = torch.tensor([rule_id], dtype=torch.long) + list_of_graphs.append(pyg_graph) + else: + continue + + data, slices = self.collate(list_of_graphs) + if self.output_path: + makedirs(os.path.dirname(self.output_path)) + torch.save((data, slices), self.output_path) + + return data, slices + + +class FilteringPolicyDataset(InMemoryDataset): + """ + Policy network dataset 
+ """ + + def __init__(self, molecules_path, reaction_rules_path, output_path, num_cpus=1): + """ + Initializes a policy network dataset object. + + :param molecules_path: The path to the file containing the molecules data + :param reaction_rules_path: The path to the file containing the reaction rules. + :param output_path: The output path is the location where policy network dataset will be stored. + :param num_cpus: Specifies the number of CPUs to be used for the data preparation. + """ + super().__init__(None, None, None) + + self.molecules_path = molecules_path + self.reaction_rules_path = reaction_rules_path + self.output_path = output_path + self.num_cpus = num_cpus + self.batch_size = 10 + + if output_path and os.path.exists(output_path): + self.data, self.slices = torch.load(self.output_path) + else: + self.data, self.slices = self.prepare_data() + + @property + def num_classes(self) -> int: + return self._data.y_rules.shape[1] + + def prepare_data(self): + """ + The function prepares data by loading reaction rules, initializing Ray, preprocessing the molecules, collating + the data, and returning the data and slices. + :return: data (PyTorch geometric graphs) and slices. 
+ """ + + ray.init(num_cpus=self.num_cpus, ignore_reinit_error=True) + reaction_rules = load_reaction_rules(self.reaction_rules_path) + reaction_rules_ids = ray.put(reaction_rules) + + to_process = Queue(maxsize=self.batch_size * self.num_cpus) + processed_data = [] + results_ids = [preprocess_filtering_policy_molecules.remote(to_process, reaction_rules_ids) for _ in range(self.num_cpus)] + + with open(self.molecules_path, "r") as inp_data: + for molecule in tqdm(inp_data.read().splitlines()): + to_process.put(molecule) + + results = [graph for res in ray.get(results_ids) if res for graph in res] + processed_data.extend(results) + + ray.shutdown() + + for pyg in processed_data: + pyg.y_rules = pyg.y_rules.to_dense() + pyg.y_priority = pyg.y_priority.to_dense() + + data, slices = self.collate(processed_data) + if self.output_path: + makedirs(os.path.dirname(self.output_path)) + torch.save((data, slices), self.output_path) + + return data, slices + + def prepare_data_no_ray(self): + ####### + # /!\ Possible alternatives to ray, has to be checked once pytorch updated + ####### + + global reaction_rules + reaction_rules = load_reaction_rules(self.reaction_rules_path) + + with Manager() as m, Pool() as p: + to_process = m.Queue() + + processed_data = [] + # print(f'{len(mols_batches)} batches were created with {len(mols_batches[0])} molecules each') + mols_batch = [] + with open(self.molecules_path, "r") as inp_data: + for molecule in tqdm(inp_data.read().splitlines()): + mols_batch.append(molecule) + if len(mols_batch) == self.batch_size: # * self.num_cpus: + for mol in mols_batch: + to_process.put(mol) + mols_batch = [] + workers_results = m.list() + + workers = [p.apply_async(preprocess_filtering_policy_molecules, (to_process, workers_results)) for _ in range(40)] + print([res.get() for res in workers]) + # # for i in range(40): + # # w = Process(target=preprocess_policy_molecules, args=(to_process, workers_results)) + # # w.start() + # # workers.append(w) + # # for 
def reaction_rules_appliance(molecule, reaction_rules):
    """
    Applies each rule from the list of reaction rules to a given molecule and
    returns the indexes of the successfully applied regular rules and the
    indexes of the prioritized rules.

    A rule counts as *applied* when it produces a reaction whose products all
    pass the valence check.  A rule is additionally *prioritized* when it looks
    like a retro coupling (several large products with a small total size gain)
    or a retro cyclization (fewer rings in products than in reactants).

    :param molecule: The given molecule
    :param reaction_rules: The list of reaction rules
    :return: two lists: indexes of successfully applied regular and priority reaction rules.
    """
    applied_rules, priority_rules = [], []
    for i, rule in enumerate(reaction_rules):

        rule_applied = False
        rule_prioritized = False

        try:
            tmp = [molecule.copy()]
            for reaction in rule(tmp):
                for prod in reaction.products:
                    prod.kekule()
                    if prod.check_valence():  # truthy -> broken valence, reject this reaction
                        break
                else:
                    rule_applied = True

                # check priority rules
                if len(reaction.products) > 1:
                    # check coupling retro manual
                    if all(len(mol) > 6 for mol in reaction.products):
                        if sum(len(mol) for mol in reaction.products) - len(reaction.reactants[0]) < 6:
                            rule_prioritized = True
                else:
                    # check cyclization retro manual
                    if sum(len(mol.sssr) for mol in reaction.products) < sum(
                            len(mol.sssr) for mol in reaction.reactants):
                        rule_prioritized = True

            if rule_applied:
                applied_rules.append(i)

            if rule_prioritized:
                priority_rules.append(i)

        # BUGFIX: was a bare `except:` which also swallowed KeyboardInterrupt
        # and SystemExit; failing rules are still skipped best-effort.
        except Exception:
            continue

    return applied_rules, priority_rules
of molecules by applying reaction rules and converting molecules into PyTorch + geometric graphs. Successfully applied rules are converted to binary vectors for policy network training. + + :param to_process: The queue containing SMILES of molecules to be converted to the training data. + :type to_process: Queue + :param reaction_rules: The list of reaction rules. + :type reaction_rules: List[Reactor] + :return: a list of PyGraph objects. + """ + + pyg_graphs = [] + while True: + try: + molecule = smiles(to_process.get(timeout=30)) + if not isinstance(molecule, MoleculeContainer): + continue + + # reaction reaction_rules application + applied_rules, priority_rules = reaction_rules_appliance(molecule, reaction_rules) + y_rules = torch.sparse_coo_tensor([applied_rules], torch.ones(len(applied_rules)), + (len(reaction_rules),), dtype=torch.uint8) + y_priority = torch.sparse_coo_tensor([priority_rules], torch.ones(len(priority_rules)), + (len(reaction_rules),), dtype=torch.uint8) + + y_rules = torch.unsqueeze(y_rules, 0) + y_priority = torch.unsqueeze(y_priority, 0) + + pyg_graph = mol_to_pyg(molecule) + if not pyg_graph: + continue + pyg_graph.y_rules = y_rules + pyg_graph.y_priority = y_priority + pyg_graphs.append(pyg_graph) + except Empty: + break + return pyg_graphs + + +def preprocess_policy_molecules_no_ray(to_process: Queue, workers_results): + ####### + # /!\ Possible alternatives to ray, has to be checked once pytorch updated + ####### + + pyg_graphs = [] + while True: + try: + molecule_str = to_process.get(timeout=1) + molecule = smiles(molecule_str) + if not isinstance(molecule, MoleculeContainer): + continue + + # reaction reaction_rules application + applied_rules, priority_rules = reaction_rules_appliance(molecule, reaction_rules) + y_rules = torch.sparse_coo_tensor([applied_rules], torch.ones(len(applied_rules)), (len(reaction_rules),), + dtype=torch.uint8) + y_priority = torch.sparse_coo_tensor([priority_rules], torch.ones(len(priority_rules)), + 
def atom_to_vector(atom):
    """
    Encode a single atom as a fixed-length feature vector of length 8:

    1. Atomic number
    2. Period
    3. Group
    4. Number of electrons + atom's charge
    5. Shell
    6. Total number of hydrogens
    7. Whether the atom is in a ring
    8. Number of neighbors

    :param atom: the atom object
    :return: a ``torch.uint8`` tensor of length 8.
    """
    # period/group/shell/electrons come from the module-level MENDEL_INFO table
    period, group, shell, electrons = MENDEL_INFO[atom.atomic_symbol]
    features = (
        atom.atomic_number,
        period,
        group,
        electrons + atom.charge,
        shell,
        atom.total_hydrogens,
        int(atom.in_ring),
        atom.neighbors,
    )
    return torch.tensor(features, dtype=torch.uint8)
+ """ + vector = torch.zeros(3, dtype=torch.uint8) + for b_order in molecule._bonds[atom_ind].values(): + vector[int(b_order) - 1] += 1 + return vector + + +def mol_to_matrix(molecule: MoleculeContainer): + """ + Given a target, it returns a vector of shape (max_atoms, 12) where each row is an atom and each + column is a feature. + + :param molecule: The target to be converted to a vector + :type molecule: MoleculeContainer + :return: The atoms_vectors array + """ + + atoms_vectors = torch.zeros((len(molecule), 11), dtype=torch.uint8) + for n, atom in molecule.atoms(): + atoms_vectors[n - 1][:8] = atom_to_vector(atom) + for n, _ in molecule.atoms(): + atoms_vectors[n - 1][8:] = bonds_to_vector(molecule, n) + + return atoms_vectors + + +def mol_to_pyg(molecule: MoleculeContainer, canonicalize=True): + """ + It takes a list of molecules and returns a list of PyTorch Geometric graphs, + a one-hot encoded vectors of the atoms, and a matrices of the bonds. + + :param canonicalize: + :param molecule: The molecule to be converted to PyTorch Geometric graph. 
+ :return: A list of pyg graphs + """ + tmp_molecule = molecule.copy() + try: + if canonicalize: + tmp_molecule.canonicalize() + tmp_molecule.kekule() + if tmp_molecule.check_valence(): + return None + except InvalidAromaticRing: + return None + + # remapping target for torch_geometric because + # it is necessary that the elements in edge_index only hold nodes_idx in the range { 0, ..., num_nodes - 1} + new_mappings = {n: i for i, (n, _) in enumerate(tmp_molecule.atoms(), 1)} + tmp_molecule.remap(new_mappings) + + # get edge indexes from target mapping + edge_index = [] + for atom, neighbour, bond in tmp_molecule.bonds(): + edge_index.append([atom - 1, neighbour - 1]) + edge_index = torch.tensor(edge_index, dtype=torch.long) + + # + x = mol_to_matrix(tmp_molecule) + + mol_pyg_graph = Data(x=x, edge_index=edge_index.t().contiguous()) + mol_pyg_graph = ToUndirected()(mol_pyg_graph) + + assert mol_pyg_graph.is_undirected() + return mol_pyg_graph + + +MENDEL_INFO = {"Ag": (5, 11, 1, 1), "Al": (3, 13, 2, 1), "Ar": (3, 18, 2, 6), "As": (4, 15, 2, 3), "B": (2, 13, 2, 1), + "Ba": (6, 2, 1, 2), "Bi": (6, 15, 2, 3), "Br": (4, 17, 2, 5), "C": (2, 14, 2, 2), "Ca": (4, 2, 1, 2), + "Ce": (6, None, 1, 2), "Cl": (3, 17, 2, 5), "Cr": (4, 6, 1, 1), "Cs": (6, 1, 1, 1), "Cu": (4, 11, 1, 1), + "Dy": (6, None, 1, 2), "Er": (6, None, 1, 2), "F": (2, 17, 2, 5), "Fe": (4, 8, 1, 2), "Ga": (4, 13, 2, 1), + "Gd": (6, None, 1, 2), "Ge": (4, 14, 2, 2), "Hg": (6, 12, 1, 2), "I": (5, 17, 2, 5), "In": (5, 13, 2, 1), + "K": (4, 1, 1, 1), "La": (6, 3, 1, 2), "Li": (2, 1, 1, 1), "Mg": (3, 2, 1, 2), "Mn": (4, 7, 1, 2), + "N": (2, 15, 2, 3), "Na": (3, 1, 1, 1), "Nd": (6, None, 1, 2), "O": (2, 16, 2, 4), "P": (3, 15, 2, 3), + "Pb": (6, 14, 2, 2), "Pd": (5, 10, 3, 10), "Pr": (6, None, 1, 2), "Rb": (5, 1, 1, 1), "S": (3, 16, 2, 4), + "Sb": (5, 15, 2, 3), "Se": (4, 16, 2, 4), "Si": (3, 14, 2, 2), "Sm": (6, None, 1, 2), "Sn": (5, 14, 2, 2), + "Sr": (5, 2, 1, 2), "Te": (5, 16, 2, 4), "Ti": (4, 4, 1, 2), 
"Tl": (6, 13, 2, 1), "Yb": (6, None, 1, 2), + "Zn": (4, 12, 1, 2)} diff --git a/SynTool/ml/training/reinforcement.py b/SynTool/ml/training/reinforcement.py new file mode 100755 index 0000000000000000000000000000000000000000..c65bc20d7435e860a7c455b58409e84515282d3b --- /dev/null +++ b/SynTool/ml/training/reinforcement.py @@ -0,0 +1,300 @@ +""" +Module containing functions for running value network tuning with self-tuning approach +""" + +import os.path +from collections import defaultdict +from pathlib import Path +from random import shuffle + +import torch +from CGRtools.containers import MoleculeContainer +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import LearningRateMonitor +from torch.utils.data import random_split +from torch_geometric.data.lightning import LightningDataset +from tqdm import tqdm + +from SynTool.mcts.tree import Tree +from SynTool.utils.files import MoleculeReader +from SynTool.ml.training.preprocessing import ValueNetworkDataset +from SynTool.chem.retron import compose_retrons +from SynTool.utils.logging import DisableLogger, HiddenPrints +from SynTool.ml.networks.value import ValueNetwork +from SynTool.utils.loading import load_value_net +from SynTool.mcts.expansion import PolicyFunction +from SynTool.mcts.evaluation import ValueFunction +from SynTool.utils.config import TreeConfig, PolicyNetworkConfig, ValueNetworkConfig, ReinforcementConfig + + +def create_value_network(value_config): + + weights_path = Path(value_config.weights_path) + value_network = ValueNetwork(vector_dim=value_config.vector_dim, + batch_size=value_config.batch_size, + dropout=value_config.dropout, + num_conv_layers=value_config.num_conv_layers, + learning_rate=value_config.learning_rate) + + with DisableLogger() as DL, HiddenPrints() as HP: + trainer = Trainer() + trainer.strategy.connect(value_network) + trainer.save_checkpoint(weights_path) + + return value_network + + + +def create_targets_batch(targets, batch_size): + + num_targets = 
def extract_tree_retrons(tree_list):
    """
    Collect labelled retrons from a list of built search trees.

    Every node on the path from a solved node up to (but excluding) the root
    contributes its composed retrons with label 1.0; every unsolved node
    contributes its composed retrons with label 0.0.  The resulting mapping is
    returned with its keys in random order.

    NOTE(review): a SMILES seen under both solved and unsolved nodes keeps the
    label of the last visit (iteration-order dependent) — confirm intended.

    :param tree_list: list of built trees.
    :return: dict mapping composed retron SMILES to 1.0 / 0.0 labels.
    """
    extracted_retrons = defaultdict(float)
    for tree in tree_list:
        for idx, node in tree.nodes.items():
            if node.is_solved():
                # propagate the positive label along the path to the root
                parent = idx
                while parent and parent != 1:
                    smi = str(compose_retrons(tree.nodes[parent].new_retrons))
                    extracted_retrons[smi] = 1.0
                    parent = tree.parents[parent]
            else:
                smi = str(compose_retrons(tree.nodes[idx].new_retrons))
                extracted_retrons[smi] = 0.0

    # shuffle extracted retrons
    shuffled_keys = list(extracted_retrons)
    shuffle(shuffled_keys)
    return {key: extracted_retrons[key] for key in shuffled_keys}
def create_tuning_set(extracted_retrons, batch_size=1):
    """
    Build the value-network tuning dataset from the retrons extracted during
    the planning stage.

    The labelled retrons are converted to PyG graphs and split 60/40 into
    training and validation subsets with a fixed random seed (42), so splits
    are reproducible across tuning iterations.

    :param extracted_retrons: mapping of composed retron SMILES to 1.0/0.0 labels.
    :param batch_size: batch size of the returned data loaders.
    :return: A LightningDataset object, which contains the tuning sets for value network tuning
    """
    full_dataset = ValueNetworkDataset(extracted_retrons)
    train_size = int(0.6 * len(full_dataset))
    val_size = len(full_dataset) - train_size

    train_set, val_set = random_split(
        full_dataset, [train_size, val_size], torch.Generator().manual_seed(42)
    )

    print(f"Training set size: {len(train_set)}")
    print(f"Validation set size: {len(val_set)}")

    return LightningDataset(train_set, val_set, batch_size=batch_size, pin_memory=True, drop_last=True)
def run_training(extracted_retrons=None, value_config=None):
    """
    Training stage of the self-tuning loop: fit the value network on the
    retrons extracted during planning and save the updated weights.

    :param extracted_retrons: mapping of composed retron SMILES to 1.0/0.0 labels.
    :param value_config: value network configuration (batch size, weights path).
    """
    # create training set
    training_set = create_tuning_set(extracted_retrons=extracted_retrons, batch_size=value_config.batch_size)

    # retrain value network
    tune_value_network(datamodule=training_set, value_config=value_config)


def run_planning(targets_batch: list,
                 tree_config: TreeConfig,
                 policy_config: PolicyNetworkConfig,
                 value_config: ValueNetworkConfig,
                 reaction_rules_path: str,
                 building_blocks_path: str,
                 targets_batch_id: int):
    """
    Planning stage of the self-tuning loop: run a tree search for every target
    in the batch and collect the built trees for later retron extraction.

    :param targets_batch: batch of target molecules.
    :param tree_config: tree search configuration.
    :param policy_config: policy network configuration.
    :param value_config: value network configuration.
    :param reaction_rules_path: path to the reaction rules file.
    :param building_blocks_path: path to the building blocks file.
    :param targets_batch_id: ordinal number of the batch (used for logging).
    :return: list of built trees (targets whose search failed are skipped).
    """
    print(f'\nProcess batch number {targets_batch_id}')
    tree_list = []
    tree_config.silent = True
    for target in tqdm(targets_batch):
        try:
            tree = run_tree_search(target=target,
                                   tree_config=tree_config,
                                   policy_config=policy_config,
                                   value_config=value_config,
                                   reaction_rules_path=reaction_rules_path,
                                   building_blocks_path=building_blocks_path)
            tree_list.append(tree)
        # BUGFIX: was a bare `except:` which also swallowed KeyboardInterrupt;
        # failed targets are still skipped best-effort.
        except Exception:
            continue

    num_solved = sum(len(tree.winning_nodes) > 0 for tree in tree_list)
    print(f"Planning is finished with {num_solved} solved targets")

    return tree_list


def run_reinforcement_tuning(targets_path: str,
                             tree_config: TreeConfig,
                             policy_config: PolicyNetworkConfig,
                             value_config: ValueNetworkConfig,
                             reinforce_config: ReinforcementConfig,
                             reaction_rules_path: str,
                             building_blocks_path: str,
                             results_root=None):
    """
    Performs self-tuning simulations with alternating planning and training stages.

    :param targets_path: path to the file with target molecules.
    :param tree_config: tree search configuration.
    :param policy_config: policy network configuration.
    :param value_config: value network configuration (its weights_path is overwritten).
    :param reinforce_config: reinforcement-loop configuration (batch size).
    :param reaction_rules_path: path to the reaction rules file.
    :param building_blocks_path: path to the building blocks file.
    :param results_root: folder where the tuned value network checkpoint is stored.
    """
    # create results root folder
    results_root = Path(results_root)
    if not results_root.exists():
        results_root.mkdir()

    # load targets list
    with MoleculeReader(targets_path) as targets:
        targets = list(targets)

    # create value neural network (checkpoint is written to disk and reloaded later)
    value_config.weights_path = os.path.join(results_root, 'value_network.ckpt')
    create_value_network(value_config)

    # create targets batch
    targets_batch_list = create_targets_batch(targets, batch_size=reinforce_config.batch_size)

    # run reinforcement training
    for batch_id, targets_batch in enumerate(targets_batch_list, start=1):

        # start tree planning simulation for batch of targets
        tree_list = run_planning(targets_batch=targets_batch,
                                 tree_config=tree_config,
                                 policy_config=policy_config,
                                 value_config=value_config,
                                 reaction_rules_path=reaction_rules_path,
                                 building_blocks_path=building_blocks_path,
                                 targets_batch_id=batch_id)

        # extract pos and neg retrons from the list of built trees
        extracted_retrons = extract_tree_retrons(tree_list)

        # TODO there is a problem with batch size in lightning
        # train value network for extracted retrons
        run_training(extracted_retrons=extracted_retrons, value_config=value_config)
+""" + +import warnings +from pathlib import Path + +import torch +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.loggers import CSVLogger +from torch.utils.data import random_split +from torch_geometric.data.lightning import LightningDataset + +from SynTool.ml.networks.policy import PolicyNetwork +from SynTool.utils.config import PolicyNetworkConfig +from SynTool.ml.training.preprocessing import RankingPolicyDataset, FilteringPolicyDataset +from SynTool.utils.logging import DisableLogger, HiddenPrints + +warnings.filterwarnings("ignore") + + +def create_policy_dataset( + reaction_rules_path: str, + molecules_or_reactions_path: str, + output_path: str, + dataset_type: str = "filtering", + batch_size: int = 100, + num_cpus: int = 1, + training_data_ratio: float = 0.8, +): + """ + Generic function to create a training dataset for a policy network. + + :param dataset_type: Type of the dataset to be created ('ranking' or 'filtering'). + :param reaction_rules_path: Path to the reaction rules file. + :param molecules_or_reactions_path: Path to the molecules or reactions file. + :param output_path: Path to store the processed dataset. + :param batch_size: Size of each data batch. + :param num_cpus: Number of CPUs to use for data processing. + :param training_data_ratio: Ratio of training data to total data. + :return: A `LightningDataset` object containing training and validation datasets. + """ + with DisableLogger(): + if dataset_type == "filtering": + full_dataset = FilteringPolicyDataset( + reaction_rules_path=reaction_rules_path, + molecules_path=molecules_or_reactions_path, + output_path=output_path, + num_cpus=num_cpus, + ) + elif dataset_type == "ranking": + full_dataset = RankingPolicyDataset( + reaction_rules_path=reaction_rules_path, + reactions_path=molecules_or_reactions_path, + output_path=output_path, + ) + else: + raise ValueError("Invalid dataset type. 
Must be 'ranking' or 'filtering'.") + + train_size = int(training_data_ratio * len(full_dataset)) + val_size = len(full_dataset) - train_size + + train_dataset, val_dataset = random_split( + full_dataset, [train_size, val_size], torch.Generator().manual_seed(42) + ) + print(f"Training set size: {len(train_dataset)}, validation set size: {len(val_dataset)}") + + datamodule = LightningDataset( + train_dataset, + val_dataset, + batch_size=batch_size, + pin_memory=True, + drop_last=True, + ) + return datamodule + + +def run_policy_training( + datamodule: LightningDataset, + config: PolicyNetworkConfig, + results_path: str, + accelerator: str = "gpu", +): + """ + Trains a policy network using a given datamodule and training configuration. + + :param datamodule: A PyTorch Lightning `DataModule` class instance. It is responsible for + loading, processing, and preparing the training data for the model. + :param config: The dictionary that contains various configuration settings for the policy training process. + :param results_path: Path to store the training results and logs. + :param accelerator: The type of hardware accelerator to use for training (e.g., 'gpu', 'cpu'). + Defaults to "gpu". + :param devices: A list of device indices to use for training. Defaults to [0]. + :param silent: If True (the default) all logging information will be not printed + + This function sets up the environment for training a policy network. It includes creating directories + for storing logs and weights, initializing the network with the specified configuration, and setting up + training callbacks like LearningRateMonitor and ModelCheckpoint. The Trainer from PyTorch Lightning is + used to manage the training process. If 'silent' is set to True, the function suppresses the standard + output and logging information during training. + + The function creates three subdirectories within the specified 'results_path': + - 'logs/' for storing training logs. 
+ - 'weights/' for saving model checkpoints. + """ + results_path = Path(results_path) + results_path.mkdir(exist_ok=True) + + weights_path = results_path.joinpath("policy_network.ckpt") + + network = PolicyNetwork( + vector_dim=config.vector_dim, + n_rules=datamodule.train_dataset.dataset.num_classes, + batch_size=config.batch_size, + dropout=config.dropout, + num_conv_layers=config.num_conv_layers, + learning_rate=config.learning_rate, + policy_type=config.policy_type, + ) + + lr_monitor = LearningRateMonitor(logging_interval="epoch") + with DisableLogger(), HiddenPrints(): + trainer = Trainer( + accelerator=accelerator, + devices=[0], + max_epochs=config.num_epoch, + logger=False, + gradient_clip_val=1.0, + enable_checkpointing=False, + enable_progress_bar=False + ) + + trainer.fit(network, datamodule) + ba = round(trainer.logged_metrics['train_balanced_accuracy_y_step'].item(), 3) + trainer.save_checkpoint(weights_path) + + print(f'Policy network balanced accuracy: {ba}') diff --git a/SynTool/utils/__init__.py b/SynTool/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..7bb246ea1b257c954dc19536666c3e8d0e983619 --- /dev/null +++ b/SynTool/utils/__init__.py @@ -0,0 +1,4 @@ +from typing import Union +from os import PathLike + +path_type = Union[str, PathLike] diff --git a/SynTool/utils/__pycache__/__init__.cpython-310.pyc b/SynTool/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..feed20e9bba43bba328ec3d9e15057db41d55fbd Binary files /dev/null and b/SynTool/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/SynTool/utils/__pycache__/config.cpython-310.pyc b/SynTool/utils/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6af44a84535d982744246d9b08fb5bd9c4d04e46 Binary files /dev/null and b/SynTool/utils/__pycache__/config.cpython-310.pyc differ diff --git a/SynTool/utils/__pycache__/files.cpython-310.pyc 
@dataclass
class ConfigABC(ABC):
    """
    Abstract base class for configuration classes.

    Subclasses are dataclasses whose fields are the configuration options.
    The base provides dict/YAML (de)serialization and triggers parameter
    validation automatically after dataclass initialization.
    """

    @staticmethod
    @abstractmethod
    def from_dict(config_dict: Dict[str, Any]):
        """
        Create an instance of the configuration from a dictionary.
        """
        pass

    @staticmethod
    @abstractmethod
    def from_yaml(file_path: str):
        """
        Deserialize a YAML file into a configuration object.
        """
        pass

    @abstractmethod
    def _validate_params(self, params: Dict[str, Any]):
        """
        Validate configuration parameters.

        Implementations should raise ValueError (or TypeError) on the first
        invalid parameter encountered.
        """
        pass

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the configuration into a dictionary.

        Path objects are stringified so the result is YAML-serializable.
        """
        return {k: str(v) if isinstance(v, Path) else v for k, v in self.__dict__.items()}

    def to_yaml(self, file_path: str):
        """
        Serialize the configuration to a YAML file.

        Args:
            file_path: Path where the YAML file will be saved.
        """
        with open(file_path, "w") as file:
            yaml.dump(self.to_dict(), file)

    def __post_init__(self):
        # Call _validate_params method after initialization
        params = self.to_dict()  # Convert the current instance to a dictionary
        self._validate_params(params)
+ """ + with open(file_path, "r") as file: + config_dict = yaml.safe_load(file) + return ReactionStandardizationConfig.from_dict(config_dict) + + def _validate_params(self, params: Dict[str, Any]): + """ + Validate the parameters of the configuration. + """ + if not isinstance(params["ignore_mapping"], bool): + raise ValueError("ignore_mapping must be a boolean.") + + if not isinstance(params["skip_errors"], bool): + raise ValueError("skip_errors must be a boolean.") + + if not isinstance(params["keep_unbalanced_ions"], bool): + raise ValueError("keep_unbalanced_ions must be a boolean.") + + if not isinstance(params["keep_reagents"], bool): + raise ValueError("keep_reagents must be a boolean.") + + if not isinstance(params["action_on_isotopes"], bool): + raise ValueError("action_on_isotopes must be a boolean.") + + +@dataclass +class RuleExtractionConfig(ConfigABC): + """ + Configuration class for extracting reaction rules, inheriting from ConfigABC. + + :ivar multicenter_rules: If True, extracts a single rule encompassing all centers. + If False, extracts separate reaction rules for each reaction center in a multicenter reaction. + :ivar as_query_container: If True, the extracted rules are generated as QueryContainer objects, + analogous to SMARTS objects for pattern matching in chemical structures. + :ivar reverse_rule: If True, reverses the direction of the reaction for rule extraction. + :ivar reactor_validation: If True, validates each generated rule in a chemical reactor to ensure correct + generation of products from reactants. + :ivar include_func_groups: If True, includes specific functional groups in the reaction rule in addition + to the reaction center and its environment. + :ivar func_groups_list: A list of functional groups to be considered when include_func_groups is True. + :ivar include_rings: If True, includes ring structures in the reaction rules. + :ivar keep_leaving_groups: If True, retains leaving groups in the extracted reaction rule. 
@dataclass
class RuleExtractionConfig(ConfigABC):
    """
    Configuration class for extracting reaction rules, inheriting from ConfigABC.

    :ivar multicenter_rules: if True, extracts a single rule encompassing all centers;
        otherwise a separate rule per reaction center of a multicenter reaction.
    :ivar as_query_container: if True, rules are generated as QueryContainer objects,
        analogous to SMARTS patterns for substructure matching.
    :ivar reverse_rule: if True, reverses the reaction direction for rule extraction.
    :ivar reactor_validation: if True, validates each generated rule in a chemical
        reactor to ensure correct generation of products from reactants.
    :ivar include_func_groups: if True, includes specific functional groups in the
        rule in addition to the reaction center and its environment.
    :ivar func_groups_list: functional groups considered when include_func_groups is True.
    :ivar include_rings: if True, includes ring structures in the reaction rules.
    :ivar keep_leaving_groups: if True, retains leaving groups in the extracted rule.
    :ivar keep_incoming_groups: if True, retains incoming groups in the extracted rule.
    :ivar keep_reagents: if True, includes reagents in the extracted rule.
    :ivar environment_atom_count: environment radius around the reaction center
        (0 = center only, 1 = first shell, etc.).
    :ivar min_popularity: minimum number of times a rule must occur to be kept.
    :ivar keep_metadata: if True, retains reaction metadata in the extracted rule.
    :ivar single_reactant_only: if True, keeps only single-reactant rules.
    :ivar atom_info_retention: per-atom information to retain for the
        'reaction_center' and 'environment' sections; missing entries are
        filled in from the defaults.
    """

    multicenter_rules: bool = True
    as_query_container: bool = True
    reverse_rule: bool = True
    reactor_validation: bool = True
    include_func_groups: bool = False
    func_groups_list: List[Union[MoleculeContainer, QueryContainer]] = field(default_factory=list)
    include_rings: bool = False
    keep_leaving_groups: bool = True
    keep_incoming_groups: bool = False
    keep_reagents: bool = False
    environment_atom_count: int = 1
    min_popularity: int = 3
    keep_metadata: bool = False
    single_reactant_only: bool = True
    atom_info_retention: Dict[str, Dict[str, bool]] = field(default_factory=dict)

    def __post_init__(self):
        # Fill defaults BEFORE validation: the validator requires the
        # 'reaction_center'/'environment' keys, so validating the raw
        # (possibly empty) mapping first would reject the default config.
        self._initialize_default_atom_info_retention()
        super().__post_init__()

    def _initialize_default_atom_info_retention(self):
        """Merge user-supplied atom_info_retention over the built-in defaults."""
        default_atom_info = {
            "reaction_center": {
                "neighbors": True,
                "hybridization": True,
                "implicit_hydrogens": True,
                "ring_sizes": True,
            },
            "environment": {
                "neighbors": True,
                "hybridization": True,
                "implicit_hydrogens": True,
                "ring_sizes": True,
            },
        }

        if not self.atom_info_retention:
            self.atom_info_retention = default_atom_info
        else:
            # Bug fix: the original updated each sub-dict with itself (a no-op)
            # and raised KeyError when a section was missing entirely.
            # Layer user values on top of the defaults instead.
            for section, defaults in default_atom_info.items():
                merged = dict(defaults)
                merged.update(self.atom_info_retention.get(section, {}))
                self.atom_info_retention[section] = merged

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "RuleExtractionConfig":
        """
        Creates a RuleExtractionConfig instance from a dictionary of configuration parameters.

        :param config_dict: A dictionary containing configuration parameters.
        :return: An instance of RuleExtractionConfig.
        """
        return RuleExtractionConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "RuleExtractionConfig":
        """
        Deserializes a YAML file into a RuleExtractionConfig object.

        :param file_path: Path to the YAML file containing configuration parameters.
        :return: An instance of RuleExtractionConfig.
        """
        with open(file_path, "r") as file:
            config_dict = yaml.safe_load(file)
        return RuleExtractionConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]):
        """
        Validate the parameters of the configuration.

        :raises ValueError: on the first invalid parameter encountered.
        """
        # Boolean flags (messages match the original wording).
        for name in ("multicenter_rules", "as_query_container", "reverse_rule",
                     "reactor_validation", "include_func_groups", "include_rings",
                     "keep_leaving_groups", "keep_incoming_groups", "keep_reagents",
                     "keep_metadata", "single_reactant_only"):
            if not isinstance(params[name], bool):
                raise ValueError(f"{name} must be a boolean.")

        if params["func_groups_list"] is not None and not all(
            isinstance(group, (MoleculeContainer, QueryContainer))
            for group in params["func_groups_list"]
        ):
            raise ValueError(
                "func_groups_list must be a list of MoleculeContainer or QueryContainer objects."
            )

        for name in ("environment_atom_count", "min_popularity"):
            if not isinstance(params[name], int):
                raise ValueError(f"{name} must be an integer.")

        if params["atom_info_retention"] is not None:
            if not isinstance(params["atom_info_retention"], dict):
                raise ValueError("atom_info_retention must be a dictionary.")

            required_keys = {"reaction_center", "environment"}
            if not required_keys.issubset(params["atom_info_retention"]):
                missing_keys = required_keys - set(params["atom_info_retention"])
                raise ValueError(
                    f"atom_info_retention missing required keys: {missing_keys}"
                )

            expected_subkeys = {
                "neighbors",
                "hybridization",
                "implicit_hydrogens",
                "ring_sizes",
            }
            for key, value in params["atom_info_retention"].items():
                if key not in required_keys:
                    raise ValueError(f"Unexpected key in atom_info_retention: {key}")
                if not isinstance(value, dict) or not expected_subkeys.issubset(value):
                    missing_subkeys = expected_subkeys - set(value)
                    raise ValueError(
                        f"Invalid structure for {key} in atom_info_retention. Missing subkeys: {missing_subkeys}"
                    )
                for subkey, subvalue in value.items():
                    if not isinstance(subvalue, bool):
                        raise ValueError(
                            f"Value for {subkey} in {key} of atom_info_retention must be boolean."
                        )
@dataclass
class TreeConfig(ConfigABC):
    """
    Configuration class for the tree-based search algorithm, inheriting from ConfigABC.

    :ivar max_iterations: number of search iterations, defaults to 100.
    :ivar max_tree_size: maximum number of nodes in the tree, defaults to 10000.
    :ivar max_time: time limit (in seconds) for the search, defaults to 600.
    :ivar max_depth: maximum depth of the tree, defaults to 6.
    :ivar ucb_type: UCB variant: "puct", "uct" or "value", defaults to "uct".
    :ivar c_ucb: exploration-exploitation coefficient in UCB, defaults to 0.1.
    :ivar backprop_type: backpropagation algorithm: "muzero" or "cumulative".
    :ivar search_strategy: "expansion_first" or "evaluation_first".
    :ivar exclude_small: whether to exclude small molecules during search.
    :ivar evaluation_agg: aggregation of evaluation scores: "max" or "average".
    :ivar evaluation_type: node evaluation method: "random", "rollout" or "gcn".
    :ivar init_node_value: initial value for a new node, defaults to 0.0.
    :ivar epsilon: chance of random rule selection in the epsilon-greedy
        selection stage (exploration vs exploitation), defaults to 0.0.
    :ivar min_mol_size: molecules with this many or fewer heavy atoms are
        assumed to be building blocks, defaults to 6.
    :ivar silent: whether to suppress progress output, defaults to False.
    """

    max_iterations: int = 100
    max_tree_size: int = 10000
    max_time: float = 600
    max_depth: int = 6
    ucb_type: str = "uct"
    c_ucb: float = 0.1
    backprop_type: str = "muzero"
    search_strategy: str = "expansion_first"
    exclude_small: bool = True
    evaluation_agg: str = "max"
    evaluation_type: str = "gcn"
    init_node_value: float = 0.0
    epsilon: float = 0.0
    min_mol_size: int = 6
    silent: bool = False

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "TreeConfig":
        """
        Creates a TreeConfig instance from a dictionary of configuration parameters.

        :param config_dict: A dictionary containing configuration parameters.
        :return: An instance of TreeConfig.
        """
        return TreeConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "TreeConfig":
        """
        Deserializes a YAML file into a TreeConfig object.

        :param file_path: Path to the YAML file containing configuration parameters.
        :return: An instance of TreeConfig.
        """
        with open(file_path, "r") as file:
            config_dict = yaml.safe_load(file)
        return TreeConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]):
        """Validate configuration parameters; raise on the first violation."""
        if params["ucb_type"] not in ["puct", "uct", "value"]:
            raise ValueError(
                "Invalid ucb_type. Allowed values are 'puct', 'uct', 'value'."
            )
        if params["backprop_type"] not in ["muzero", "cumulative"]:
            raise ValueError(
                "Invalid backprop_type. Allowed values are 'muzero', 'cumulative'."
            )
        if params["evaluation_type"] not in ["random", "rollout", "gcn"]:
            raise ValueError(
                "Invalid evaluation_type. Allowed values are 'random', 'rollout', 'gcn'."
            )
        if params["evaluation_agg"] not in ["max", "average"]:
            raise ValueError(
                "Invalid evaluation_agg. Allowed values are 'max', 'average'."
            )
        if not isinstance(params["c_ucb"], float):
            raise TypeError("c_ucb must be a float.")
        if not isinstance(params["max_depth"], int) or params["max_depth"] < 1:
            raise ValueError("max_depth must be a positive integer.")
        if not isinstance(params["max_tree_size"], int) or params["max_tree_size"] < 1:
            raise ValueError("max_tree_size must be a positive integer.")
        if (
            not isinstance(params["max_iterations"], int)
            or params["max_iterations"] < 1
        ):
            raise ValueError("max_iterations must be a positive integer.")
        # Fixed: max_time is annotated as float, so accept int or float
        # (the original only accepted int, rejecting e.g. 600.5).
        if not isinstance(params["max_time"], (int, float)) or params["max_time"] < 1:
            raise ValueError("max_time must be a positive number.")
        if not isinstance(params["silent"], bool):
            raise TypeError("silent must be a boolean.")
        if not isinstance(params["init_node_value"], float):
            raise TypeError("init_node_value must be a float if provided.")
        if params["search_strategy"] not in ["expansion_first", "evaluation_first"]:
            raise ValueError(
                f"Invalid search_strategy: {params['search_strategy']}: "
                f"Allowed values are 'expansion_first', 'evaluation_first'"
            )
+ """ + + policy_type: str = "ranking" + vector_dim: int = 256 + batch_size: int = 500 + dropout: float = 0.4 + learning_rate: float = 0.008 + num_conv_layers: int = 5 + num_epoch: int = 100 + weights_path: str = None + + # for filtering policy + priority_rules_fraction: float = 0.5 + rule_prob_threshold: float = 0.0 + top_rules: int = 50 + + @staticmethod + def from_dict(config_dict: Dict[str, Any]) -> 'PolicyNetworkConfig': + """ + Creates a PolicyNetworkConfig instance from a dictionary of configuration parameters. + + :param config_dict: A dictionary containing configuration parameters. + :return: An instance of PolicyNetworkConfig. + """ + return PolicyNetworkConfig(**config_dict) + + @staticmethod + def from_yaml(file_path: str) -> 'PolicyNetworkConfig': + """ + Deserializes a YAML file into a PolicyNetworkConfig object. + + :param file_path: Path to the YAML file containing configuration parameters. + :return: An instance of PolicyNetworkConfig. + """ + with open(file_path, 'r') as file: + config_dict = yaml.safe_load(file) + return PolicyNetworkConfig.from_dict(config_dict) + + def _validate_params(self, params: Dict[str, Any]): + + if params['policy_type'] not in ["filtering", "ranking"]: + raise ValueError("policy_type must be either 'filtering' or 'ranking'.") + + if not isinstance(params['vector_dim'], int) or params['vector_dim'] <= 0: + raise ValueError("vector_dim must be a positive integer.") + + if not isinstance(params['batch_size'], int) or params['batch_size'] <= 0: + raise ValueError("batch_size must be a positive integer.") + + if not isinstance(params['num_conv_layers'], int) or params['num_conv_layers'] <= 0: + raise ValueError("num_conv_layers must be a positive integer.") + + if not isinstance(params['num_epoch'], int) or params['num_epoch'] <= 0: + raise ValueError("num_epoch must be a positive integer.") + + if not isinstance(params['dropout'], float) or not (0.0 <= params['dropout'] <= 1.0): + raise ValueError("dropout must be a float 
between 0.0 and 1.0.") + + if not isinstance(params['learning_rate'], float) or params['learning_rate'] <= 0.0: + raise ValueError("learning_rate must be a positive float.") + + if not isinstance(params['priority_rules_fraction'], float) or params['priority_rules_fraction'] < 0.0: + raise ValueError("priority_rules_fraction must be a non-negative positive float.") + + if not isinstance(params['rule_prob_threshold'], float) or params['rule_prob_threshold'] < 0.0: + raise ValueError("rule_prob_threshold must be a non-negative float.") + + if not isinstance(params['top_rules'], int) or params['top_rules'] <= 0: + raise ValueError("top_rules must be a positive integer.") + + +@dataclass +class ValueNetworkConfig(ConfigABC): + """ + Configuration class for the value network, inheriting from ConfigABC. + + :ivar vector_dim: Dimension of the input vectors. + :ivar batch_size: Number of samples per batch. + :ivar dropout: Dropout rate for regularization. + :ivar learning_rate: Learning rate for the optimizer. + :ivar num_conv_layers: Number of convolutional layers in the network. + :ivar num_epoch: Number of training epochs. + """ + + weights_path: str = None + vector_dim: int = 256 + batch_size: int = 500 + dropout: float = 0.4 + learning_rate: float = 0.008 + num_conv_layers: int = 5 + num_epoch: int = 100 + + @staticmethod + def from_dict(config_dict: Dict[str, Any]) -> 'ValueNetworkConfig': + """ + Creates a ValueNetworkConfig instance from a dictionary of configuration parameters. + + :ivar config_dict: A dictionary containing configuration parameters. + :return: An instance of ValueNetworkConfig. + """ + return ValueNetworkConfig(**config_dict) + + @staticmethod + def from_yaml(file_path: str) -> 'ValueNetworkConfig': + """ + Deserializes a YAML file into a ValueNetworkConfig object. + + :ivar file_path: Path to the YAML file containing configuration parameters. + :return: An instance of ValueNetworkConfig. 
+ """ + with open(file_path, 'r') as file: + config_dict = yaml.safe_load(file) + return ValueNetworkConfig.from_dict(config_dict) + + def to_yaml(self, file_path: str): + """ + Serializes the configuration to a YAML file. + + :ivar file_path: Path to the YAML file for serialization. + """ + with open(file_path, 'w') as file: + yaml.dump(self.to_dict(), file) + + def _validate_params(self, params: Dict[str, Any]): + """ + Validates the configuration parameters. + + :ivar params: A dictionary of parameters to validate. + :raises ValueError: If any parameter is invalid. + """ + if not isinstance(params['vector_dim'], int) or params['vector_dim'] <= 0: + raise ValueError("vector_dim must be a positive integer.") + + if not isinstance(params['batch_size'], int) or params['batch_size'] <= 0: + raise ValueError("batch_size must be a positive integer.") + + if not isinstance(params['num_conv_layers'], int) or params['num_conv_layers'] <= 0: + raise ValueError("num_conv_layers must be a positive integer.") + + if not isinstance(params['num_epoch'], int) or params['num_epoch'] <= 0: + raise ValueError("num_epoch must be a positive integer.") + + if not isinstance(params['dropout'], float) or not (0.0 <= params['dropout'] <= 1.0): + raise ValueError("dropout must be a float between 0.0 and 1.0.") + + if not isinstance(params['learning_rate'], float) or params['learning_rate'] <= 0.0: + raise ValueError("learning_rate must be a positive float.") + + +@dataclass +class ReinforcementConfig(ConfigABC): + """ + Configuration class for the reinforcement network training, inheriting from ConfigABC. + + :ivar batch_size: the number of samples per batch. + :ivar num_simulations: the number of num_simulations. + """ + + batch_size: int = 100 + num_simulations: int = 1 + + @staticmethod + def from_dict(config_dict: Dict[str, Any]) -> 'ReinforcementConfig': + """ + Creates a ReinforcementConfig instance from a dictionary of configuration parameters. 
@dataclass
class ReinforcementConfig(ConfigABC):
    """
    Configuration for reinforcement-learning training of the networks.

    :ivar batch_size: number of samples per batch.
    :ivar num_simulations: number of simulations to run.
    """

    batch_size: int = 100
    num_simulations: int = 1

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> 'ReinforcementConfig':
        """
        Build a ReinforcementConfig from a parameter dictionary.

        :param config_dict: mapping of configuration parameters.
        :return: a new ReinforcementConfig instance.
        """
        return ReinforcementConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> 'ReinforcementConfig':
        """
        Load a ReinforcementConfig from a YAML file.

        :param file_path: path to the YAML file with configuration parameters.
        :return: a new ReinforcementConfig instance.
        """
        with open(file_path, 'r') as file:
            loaded = yaml.safe_load(file)
        return ReinforcementConfig.from_dict(loaded)

    def _validate_params(self, params: Dict[str, Any]):
        """Reject non-positive or non-integer values for either parameter."""
        for name in ('batch_size', 'num_simulations'):
            value = params[name]
            if not isinstance(value, int) or value <= 0:
                raise ValueError(f"{name} must be a positive integer.")
+ """ + if isinstance(config_attr, dict): + return config_attr + elif isinstance(config_attr, config_type): + return config_attr.to_dict() + return None diff --git a/SynTool/utils/files.py b/SynTool/utils/files.py new file mode 100644 index 0000000000000000000000000000000000000000..67a04a613b984c47229d5a7285c25fd9ccc16010 --- /dev/null +++ b/SynTool/utils/files.py @@ -0,0 +1,253 @@ +from pathlib import Path +from os.path import splitext +from typing import Iterable, Union + +from CGRtools import smiles +from CGRtools.containers import ReactionContainer, MoleculeContainer, CGRContainer, QueryContainer +from CGRtools.files.SDFrw import SDFRead, SDFWrite +from CGRtools.files.RDFrw import RDFRead, RDFWrite + +from SynTool.chem.utils import to_reaction_smiles_record +from SynTool.utils import path_type + + +class SMILESRead: + def __init__(self, filename: path_type, **kwargs): + """ + Simplified class to read files containing a SMILES (Molecules or Reaction) string per line. + :param filename: the path and name of the SMILES file to parse + :return: None + """ + filename = str(Path(filename).resolve(strict=True)) + self._file = open(filename, "r") + self._data = self.__data() + self._len = sum(1 for _ in open(filename, "r")) #TODO replace later + + def __data(self) -> Iterable[Union[ReactionContainer, CGRContainer, MoleculeContainer]]: + for line in iter(self._file.readline, ''): + line = line.strip() + x = smiles(line) + if isinstance(x, (ReactionContainer, CGRContainer, MoleculeContainer)): + x.meta['init_smiles'] = line + yield x + + def __enter__(self): + return self + + def read(self): + """ + Parse whole SMILES file. + + :return: List of parsed molecules or reactions. 
+ """ + return list(iter(self)) + + def __iter__(self): + return (x for x in self._data) + + def __next__(self): + return next(iter(self)) + + def __len__(self): + return self._len + + def close(self): + self._file.close() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +class FileHandler: + def __init__(self, filename: path_type, **kwargs): + """ + General class to handle chemical files. + + :param filename: the path and name of the file + :type filename: path_type + + :return: None + """ + self._file = None + # filename = str(Path(filename).resolve(strict=True)) #TODO Tagir please correct bug in ReactionWriter following your modification + _, ext = splitext(filename) + file_types = { + '.smi': "SMI", + '.smiles': "SMI", + '.rdf': "RDF", + '.sdf': 'SDF', + } + try: + self._file_type = file_types[ext] + except KeyError: + raise ValueError("I don't know the file extension,", ext) + + def close(self): + self._file.close() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +class Reader(FileHandler): + def __init__(self, filename: path_type, **kwargs): + """ + General class to read chemical files. + + :param filename: the path and name of the file + :type filename: path_type + + :return: None + """ + super().__init__(filename, **kwargs) + + def __enter__(self): + return self._file + + def __iter__(self): + return iter(self._file) + + def __next__(self): + return next(self._file) + + def __len__(self): + return len(self._file) + + +class Writer(FileHandler): + def __init__(self, filename: path_type, mapping: bool = True, **kwargs): + """ + General class to write chemical files. 
class ReactionReader(Reader):
    def __init__(self, filename: path_type, **kwargs):
        """
        Reader for reaction files.

        Dispatches on the format detected by FileHandler: SMILES files go
        through SMILESRead, RDF files through RDFRead.

        :param filename: the path and name of the file
        :type filename: path_type

        :return: None
        """
        super().__init__(filename, **kwargs)
        if self._file_type == "RDF":
            self._file = RDFRead(filename, indexable=True, **kwargs)
        elif self._file_type == "SMI":
            self._file = SMILESRead(filename, **kwargs)
        else:
            raise ValueError("File type incompatible -", filename)
class MoleculeReader(Reader):
    def __init__(self, filename: path_type, **kwargs):
        """
        Reader for molecule files.

        Dispatches on the format detected by FileHandler: SMILES files go
        through SMILESRead, SDF files through SDFRead.

        :param filename: the path and name of the file

        :return: None
        """
        super().__init__(filename, **kwargs)
        if self._file_type == "SDF":
            self._file = SDFRead(filename, indexable=True, **kwargs)
        elif self._file_type == "SMI":
            self._file = SMILESRead(filename, ignore=True, **kwargs)
        else:
            raise ValueError("File type incompatible -", filename)
@functools.lru_cache(maxsize=None)
def load_reaction_rules(file):
    """
    Load reaction rules from a pickle file.

    The pickle is expected to contain (rule, metadata) pairs; when the first
    rule is not already a Reactor, every rule is wrapped into one.

    NOTE(security): pickle.load can execute arbitrary code — only load rule
    files from trusted sources.

    :param file: the path to the pickle file that stores the reaction rules.
    :return: a list of reaction rules (empty list for an empty file).
    """
    with open(file, "rb") as f:
        reaction_rules = pickle.load(f)

    # Guard the empty case: the probe below would raise IndexError.
    if not reaction_rules:
        return []

    # Probe the first entry; assumes pairs of (rule, metadata) when rules
    # are not yet Reactor objects — TODO confirm pickle layout with producer.
    if not isinstance(reaction_rules[0][0], Reactor):
        reaction_rules = [Reactor(x) for x, _ in reaction_rules]

    return reaction_rules
def standardize_building_blocks(input_file, output_file):  # TODO implement with reader/writer
    """
    Canonicalizes custom building blocks.

    Reads one SMILES per line, canonicalizes it, and writes the canonical
    form to the output file; lines that fail to parse or canonicalize are
    skipped.

    :param input_file: the path to the txt file that stores the original building blocks
    :param output_file: the path to the txt file that stores the canonicalized building blocks
    :return: the output file path
    """

    with open(input_file, "r") as inp_file, open(output_file, "w") as out_file:
        for smi in tqdm(inp_file):
            try:
                # Parsing moved inside the try: a malformed SMILES previously
                # crashed the whole run instead of being skipped. The line is
                # stripped of its trailing newline before parsing.
                mol = smiles(smi.strip())
                mol.canonicalize()
            except Exception:
                # Fixed: bare "except:" also swallowed KeyboardInterrupt/SystemExit.
                continue
            out_file.write(f'{str(mol)}\n')

    return output_file
pickle.load(file) + if isinstance(bb, list): + bb = set(bb) + else: + raise TypeError( + f"expected .txt, .smi, .pickle, or .pkl files, not {filetype}" + ) + + stop = time() + logging.debug(f"{len(bb)} In-Stock Substances are loaded.\nTook {round(stop - start, 2)} seconds.") + return bb + + +def load_value_net(model_class, value_network_path): + """ + Loads a model from an external path or an internal path + + :param value_network_path: + :param model_class: The model class you want to load + :type model_class: pl.LightningModule + model will be loaded from the external path + """ + + map_location = device("cpu") + return model_class.load_from_checkpoint(value_network_path, map_location) + + +def load_policy_net(model_class, policy_network_path): + """ + Loads a model from an external path or an internal path + + :param policy_network_path: + :param model_class: The model class you want to load + :type model_class: pl.LightningModule + model will be loaded from the external path + """ + + map_location = device("cpu") + # return model_class.load_from_checkpoint(policy_network_path, map_location, n_rules=n_rules, + # vector_dim=vector_dim, batch_size=1) + + return model_class.load_from_checkpoint(policy_network_path, map_location, batch_size=1) diff --git a/SynTool/utils/logging.py b/SynTool/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..15821dbfdfe602aa45bc140c283e317be2939bca --- /dev/null +++ b/SynTool/utils/logging.py @@ -0,0 +1,31 @@ +import os +import sys +import logging + + +class DisableLogger: + """ + This function mute redundant logging information. Adopted from + https://stackoverflow.com/questions/2266646/how-to-disable-logging-on-the-standard-error-stream + """ + + def __enter__(self): + logging.disable(logging.CRITICAL) + + def __exit__(self, exit_type, exit_value, exit_traceback): + logging.disable(logging.NOTSET) + + +class HiddenPrints: + """ + This function mute redundant printing information. 
Adopted from + https://stackoverflow.com/questions/8391411/how-to-block-calls-to-print + """ + + def __enter__(self): + self._original_stdout = sys.stdout + sys.stdout = open(os.devnull, 'w') + + def __exit__(self, exc_type, exc_val, exc_tb): + sys.stdout.close() + sys.stdout = self._original_stdout diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..0ea38eb25b5716e59067786335eecb5037190f0f --- /dev/null +++ b/app.py @@ -0,0 +1,168 @@ +import streamlit as st +from streamlit_ketcher import st_ketcher + +from SynTool.mcts.tree import Tree, TreeConfig +from SynTool.mcts.expansion import PolicyFunction +from SynTool.mcts.search import extract_tree_stats +from SynTool.utils.config import PolicyNetworkConfig +from SynTool.interfaces.visualisation import to_table, extract_routes +import pickle +import uuid +import base64 +import pandas as pd +import json +import re + + +def download_button(object_to_download, download_filename, button_text, pickle_it=False): + """ + Issued from + Generates a link to download the given object_to_download. + Params: + ------ + object_to_download: The object to be downloaded. + download_filename (str): filename and extension of file. e.g. mydata.csv, + some_txt_output.txt download_link_text (str): Text to display for download + link. + button_text (str): Text to display on download button (e.g. 'click here to download file') + pickle_it (bool): If True, pickle file. 
# --- app.py (reconstructed from diff) ---
import streamlit as st
from streamlit_ketcher import st_ketcher

from SynTool.mcts.tree import Tree, TreeConfig
from SynTool.mcts.expansion import PolicyFunction
from SynTool.mcts.search import extract_tree_stats
from SynTool.utils.config import PolicyNetworkConfig
from SynTool.interfaces.visualisation import to_table, extract_routes
import pickle
import uuid
import base64
import pandas as pd
import json
import re


def download_button(object_to_download, download_filename, button_text, pickle_it=False):
    """
    Generate an HTML anchor tag that downloads `object_to_download` when clicked.

    :param object_to_download: the object to be downloaded (str, bytes or DataFrame)
    :param download_filename: filename and extension of file, e.g. mydata.csv
    :param button_text: text to display on the download button (e.g. 'click here to download file')
    :param pickle_it: if True, pickle the object before encoding
    :return: the anchor tag (str) to download object_to_download, or None if pickling failed

    Examples:
        download_button(your_df, 'YOUR_DF.csv', 'Click to download data!')
        download_button(your_str, 'YOUR_STRING.txt', 'Click to download text!')
    """
    if pickle_it:
        try:
            object_to_download = pickle.dumps(object_to_download)
        except pickle.PicklingError as e:
            st.write(e)
            return None
    else:
        if isinstance(object_to_download, bytes):
            pass
        elif isinstance(object_to_download, pd.DataFrame):
            object_to_download = object_to_download.to_csv(index=False).encode('utf-8')

    try:
        # some strings <-> bytes conversions necessary here
        b64 = base64.b64encode(object_to_download.encode()).decode()
    except AttributeError:
        b64 = base64.b64encode(object_to_download).decode()

    button_uuid = str(uuid.uuid4()).replace('-', '')
    # raw string: '\d' in a plain literal is an invalid escape sequence (SyntaxWarning on 3.12+)
    button_id = re.sub(r'\d+', '', button_uuid)

    # NOTE(review): the original inline <style> block was lost during extraction;
    # this minimal styling targets the generated id — restore the intended CSS if needed.
    custom_css = f"""
        <style>
            #{button_id} {{ display: inline-flex; padding: 0.25rem 0.75rem; text-decoration: none; }}
        </style>
    """

    # Data-URI anchor: clicking it downloads the base64-encoded payload under `download_filename`.
    dl_link = custom_css + (
        f'<a download="{download_filename}" id="{button_id}" '
        f'href="data:file/txt;base64,{b64}">{button_text}</a><br></br>'
    )

    return dl_link


st.set_page_config(  # layout="wide",
    page_title="SynTool GUI",
    page_icon="🧪",
)

st.title("`SynTool GUI`")
st.write("*{Introduction text to be inserted here}*")
st.header('Molecule input')
st.write("You can provide a molecular structure by either providing its SMILES string + Enter, either by drawing it + Apply.")
DEFAULT_MOL = 'NC(CCCCB(O)O)(CCN1CCC(CO)C1)C(=O)O'
molecule = st.text_input("Molecule", DEFAULT_MOL)
smile_code = st_ketcher(molecule)

st.header('Launch calculation')
st.write("If you modified the structure, please ensure you clicked on 'Apply' (bottom right of the molecular editor).")
st.markdown(f"The molecule SMILES is actually: ``{smile_code}``")
max_depth = st.slider('Maximal number of reaction steps', min_value=2, max_value=9, value=9)
run_default = st.button('Launch and search a reaction path',)

ranking_policy_weights_path = 'data/policy_network.ckpt'
reaction_rules_path = 'data/reaction_rules.pickle'
building_blocks_path = 'data/building_blocks.smi'
policy_config = PolicyNetworkConfig(weights_path=ranking_policy_weights_path)
policy_function = PolicyFunction(policy_config=policy_config)

if run_default:
    st.toast('Optimisation is started. The progress will be printed below')
    spinner = st.spinner(text="Running with default parameters...")
    tree_config = TreeConfig(
        search_strategy="expansion_first",
        evaluation_type="rollout",
        max_iterations=100,
        max_depth=max_depth,
        min_mol_size=0,
        init_node_value=0.5,
        ucb_type="uct",
        c_ucb=0.1,
        silent=True
    )

    with spinner:
        tree = Tree(
            target=smile_code,
            tree_config=tree_config,
            reaction_rules_path=reaction_rules_path,
            building_blocks_path=building_blocks_path,
            policy_function=policy_function,
            value_function=None,
        )
        # Exhausting the tree iterator runs the full MCTS search.
        _ = list(tree)

    res = extract_tree_stats(tree, smile_code)  # extract_routes(tree)

    st.header('Results')
    if res['found_paths']:
        st.write("Success!")
        st.subheader("Retrosynthetic Routes Report")
        st.markdown(to_table(tree, None, extended=True, integration=True), unsafe_allow_html=True)
        st.subheader("Statistics")
        st.write(pd.DataFrame(res, index=[0]))
        st.subheader("Downloads")
        dl_html = download_button(to_table(tree, None, extended=True, integration=False),
                                  'results_syntool.html',
                                  'Download results as a HTML file')
        dl_csv = download_button(pd.DataFrame(res, index=[0]),
                                 'results_syntool.csv',
                                 'Download statistics as an Excel csv file')
        st.markdown(dl_html + dl_csv, unsafe_allow_html=True)
    else:
        st.write("Found no reaction path.")

st.divider()
st.header('Restart from the beginning?')
if st.button("Restart"):
    st.rerun()
b/data/policy_network.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15e6af2273534a4ce39b60568f33b50befb436ad5d27ed53bebbaffffbc1c22d +size 151649959 diff --git a/data/reaction_rules.pickle b/data/reaction_rules.pickle new file mode 100644 index 0000000000000000000000000000000000000000..69c20c8a3ab66db29c9f2b74d6e31c301d06c031 --- /dev/null +++ b/data/reaction_rules.pickle @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c6e753e4bd935617181eebb9767cef99599fc7edc9f183ab092ae5166936c14b +size 57352695 diff --git a/pre-requirements.txt b/pre-requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e9e1faa206a2f3850f5a5db71f5dc236319f73af --- /dev/null +++ b/pre-requirements.txt @@ -0,0 +1,2 @@ +--find-links https://download.pytorch.org/whl/torch_stable.html +torch==2.1.0+cpu \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..f5c8f774d3300dbd74928bc4a05f245d4bc013b0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,28 @@ +streamlit +streamlit_ketcher +CGRtools==4.1.35 +py-mini-racer +toytree>=2.0 +ray>=2.0 +click>=8.0.0 +StructureFingerprint==2.1 +werkzeug +gdown==4.6.3 +ordered-set==4.1.0 +numpy>=1.26 +chytorch==1.60 +chytorch-rxnmap==1.4 +adabelief-pytorch +scikit-learn==1.4.1.post1 +scipy==1.12.0 +pandas==2.2.1 +altair +pytorch-lightning +torch-cluster +torch-scatter +torch-sparse +torch-spline-conv +torch_geometric +torchmetrics +torchtyping +git+https://github.com/pyg-team/pyg-lib@0.3.0

Retrosynthetic Routes Report