"""Module containing functions for loading reaction rules, building blocks and
retrosynthetic models."""

import functools
import logging
import os
from pathlib import Path
import pickle
import shutil
from typing import TYPE_CHECKING, FrozenSet, List, Union
import zipfile

from CGRtools.files.SDFrw import SDFRead
from CGRtools.reactor.reactor import Reactor
from huggingface_hub import hf_hub_download, snapshot_download
from torch import device
from tqdm.auto import tqdm

from synplan.chem.utils import _standardize_sdf_text, _standardize_smiles_batch
from synplan.ml.networks.policy import PolicyNetwork
from synplan.ml.networks.value import ValueNetwork
from synplan.utils.files import (
    count_sdf_records,
    count_smiles_records,
    iter_csv_smiles,
    iter_csv_smiles_blocks,
    iter_sdf_text_blocks,
    iter_smiles,
    iter_smiles_blocks,
)
from synplan.utils.parallel import process_pool_map_stream

if TYPE_CHECKING:
    from synplan.utils.config import (
        ValueNetworkConfig,
        PolicyNetworkConfig,
        CombinedPolicyConfig,
        HybridPolicyConfig,
    )
    from synplan.mcts.expansion import (
        PolicyNetworkFunction,
        CombinedPolicyNetworkFunction,
    )
    from synplan.mcts.hybrid_policy import HybridPolicy, MHybridPolicy
    from synplan.mcts.evaluation import ValueNetworkFunction, EvaluationStrategy

REPO_ID = "Laboratoire-De-Chemoinformatique/SynPlanner"

logger = logging.getLogger(__name__)


def _building_blocks_progress(total: int | None, *, silent: bool):
    """Create a consistent progress bar for building blocks loading.

    :param total: Expected number of molecules, or None when unknown.
    :param silent: If True, no progress bar is created.
    :return: A tqdm progress bar, or None when ``silent`` is True.
    """
    if silent:
        return None
    return tqdm(
        total=total,
        desc="Building blocks",
        unit="mol",
        unit_scale=True,
        unit_divisor=1000,
        dynamic_ncols=True,
        smoothing=0.1,
        disable=silent,
    )


def _map_blocks(blocks, worker_fn, *, num_workers: int):
    """Map blocks through worker function, optionally using a process pool.

    For `num_workers == 1`, this runs sequentially to avoid process-spawn overhead.

    This is a generator, so the ``num_workers`` check runs on first iteration.

    :param blocks: Iterable of work items passed one by one to ``worker_fn``.
    :param worker_fn: Callable applied to every block.
    :param num_workers: Number of workers; must be >= 1.
    :raises ValueError: If ``num_workers`` is smaller than 1.
    """
    if num_workers < 1:
        raise ValueError("num_workers must be >= 1")
    if num_workers == 1:
        for block in blocks:
            yield worker_fn(block)
        return
    yield from process_pool_map_stream(blocks, worker_fn, max_workers=num_workers)


def _collect_standardized(blocks, worker_fn, *, total, silent: bool, num_workers: int):
    """Run ``worker_fn`` over ``blocks`` and gather all results into one set.

    :param blocks: Iterable of work items for ``worker_fn``.
    :param worker_fn: Callable returning an iterable of SMILES for one block.
    :param total: Total passed to the progress bar (None when unknown).
    :param silent: If True, no progress bar is shown.
    :param num_workers: Number of workers forwarded to ``_map_blocks``.
    :return: Set with the union of all non-empty worker outputs.
    """
    collected = set()
    progress = _building_blocks_progress(total, silent=silent)
    try:
        for out in _map_blocks(blocks, worker_fn, num_workers=num_workers):
            if out:
                collected.update(out)
                if progress is not None:
                    progress.update(len(out))
    finally:
        # Close the bar even when a worker raises, so the terminal is not
        # left with a dangling progress line.
        if progress is not None:
            progress.close()
    return collected


def _extract_zip(zip_path: Path, out_dir: Path) -> None:
    """Extract a zip into `out_dir` only if its contents are missing.

    Members whose target path already exists under ``out_dir`` are skipped.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path, "r") as zf:
        for name in zf.namelist():
            target = out_dir / name
            if not target.exists():
                zf.extract(name, out_dir)


def download_selected_files(
    files_to_get: list[tuple[str, str]],
    save_to: str | Path = "./tutorials/synplan_data",
    extract_zips: bool = True,
    relocate_map: dict[str, str] | None = None,
) -> Path:
    """
    Download specific files from the Hugging Face repo.

    Parameters
    ----------
    files_to_get : list of (subfolder, filename)
        Example:
        [("building_blocks", "building_blocks_em_sa_ln.smi.zip"),
         ("uspto", "uspto_reaction_rules.pickle"),
         ("weights", "ranking_policy_network.ckpt")]
    save_to : path
        Where to save everything locally.
    extract_zips : bool
        If True, extract .zip files to their containing folder.
    relocate_map : dict[str, str]
        Optional map { "weights/ranking_policy_network.ckpt": "uspto/weights/ranking_policy_network.ckpt" }
        to copy files after download to match test paths. Paths are relative
        to ``save_to``; an existing destination is never overwritten.

    Returns
    -------
    Path
        The resolved root directory the files were saved to.
    """
    root = Path(save_to).resolve()
    root.mkdir(parents=True, exist_ok=True)

    for subfolder, filename in files_to_get:
        local_path = Path(
            hf_hub_download(
                repo_id=REPO_ID,
                subfolder=subfolder,
                filename=filename,
                local_dir=str(root),
            )
        )
        if extract_zips and local_path.suffix == ".zip":
            _extract_zip(local_path, local_path.parent)

    if relocate_map:
        for src_rel, dst_rel in relocate_map.items():
            src = root / src_rel
            dst = root / dst_rel
            dst.parent.mkdir(parents=True, exist_ok=True)
            if src.exists() and not dst.exists():
                shutil.copy2(src, dst)  # or shutil.move(src, dst)

    return root


def download_unpack_data(filename, subfolder, save_to="."):
    """Download one file from the Hugging Face repo and unpack it if it is a zip.

    :param filename: Name of the file inside ``subfolder`` of the repository.
    :param subfolder: Repository subfolder that contains the file.
    :param save_to: Local directory (str or Path) to download into. It is
        resolved to an absolute path and created if missing.
    :return: Path to the downloaded file. For a ``.zip`` download, the whole
        archive is extracted into ``save_to``, the archive is deleted, and the
        path of the first archive member is returned.
    """
    # Normalise str and Path inputs the same way; parents=True lets callers
    # pass a nested directory that does not exist yet.
    save_to = Path(save_to).resolve()
    save_to.mkdir(parents=True, exist_ok=True)

    # Download the file from the repository
    file_path = Path(
        hf_hub_download(
            repo_id=REPO_ID,
            filename=filename,
            subfolder=subfolder,
            local_dir=save_to,
        )
    )

    if file_path.suffix != ".zip":
        return file_path

    with zipfile.ZipFile(file_path, "r") as zip_ref:
        zip_ref.extractall(save_to)
        # The archive is expected to hold a single file; report the first member.
        extracted_file = save_to / zip_ref.namelist()[0]
    file_path.unlink()
    return extracted_file


def download_all_data(save_to="."):
    """Download a snapshot of the whole repository and unpack every zip in it.

    Archive members that already exist next to their zip file are not
    extracted again.

    :param save_to: Local directory the snapshot is downloaded into.
    """
    dir_path = snapshot_download(repo_id=REPO_ID, local_dir=save_to)
    dir_path = Path(dir_path).resolve()
    for zip_file in dir_path.rglob("*.zip"):
        with zipfile.ZipFile(zip_file, "r") as zip_ref:
            # Check each file in the zip
            for file_name in zip_ref.namelist():
                extracted_file_path = zip_file.parent / file_name
                # Extract the file only if it does not exist yet
                if not extracted_file_path.exists():
                    zip_ref.extract(file_name, zip_file.parent)
                    print(f"Extracted {file_name} to {zip_file.parent}")


@functools.lru_cache(maxsize=None)
def load_reaction_rules(file: str) -> tuple[Reactor, ...]:
    """Loads the reaction rules from a pickle file and converts them into
    Reactor objects if necessary.

    The result is cached per ``file`` argument and returned as a tuple so the
    cached value cannot be mutated by callers.

    :param file: The path to the pickle file that stores the reaction rules.
    :return: A tuple of reaction rules. An empty pickle yields an empty tuple.
    """
    # NOTE: pickle can execute arbitrary code on load; only open rule files
    # that come from a trusted source.
    with open(file, "rb") as f:
        reaction_rules = pickle.load(f)

    # Entries are unpacked as (rule, extra) pairs; wrap the rule in a Reactor
    # unless the first entry shows it has been converted already.
    if reaction_rules and not isinstance(reaction_rules[0][0], Reactor):
        reaction_rules = [Reactor(x) for x, _ in reaction_rules]

    return tuple(reaction_rules)


@functools.lru_cache(maxsize=None)
def load_building_blocks(
    building_blocks_path: Union[str, Path],
    standardize: bool = True,
    silent: bool = True,
    num_workers: int | None = None,
    chunksize: int = 1000,
    *,
    header: bool = True,
    delimiter: str = ",",
    smiles_column: str = "SMILES",
) -> FrozenSet[str]:
    """Loads building blocks data from a file and returns a frozen set of building blocks.

    Results are cached per combination of arguments.

    :param building_blocks_path: The path to the file containing the building blocks.
    :param standardize: Flag if building blocks have to be standardized before loading.
        Default=True.
    :param silent: If False, a progress bar is shown while standardizing. Default=True.
    :param num_workers: Number of workers used for standardization. Defaults to the
        CPU count minus one (at least 1). Ignored when ``standardize`` is False.
    :param chunksize: Block size passed to the block readers when standardizing.
        Default=1000.
    :param header: For CSV/CSV.GZ files: treat the first row as header. Default=True.
    :param delimiter: For CSV/CSV.GZ files: delimiter character. Default=",".
    :param smiles_column: For CSV/CSV.GZ files: header column name containing SMILES.
        Default="SMILES" (case-insensitive match is supported).
    :return: The set of building blocks smiles.
    :raises ValueError: If the file extension is unsupported or ``num_workers`` < 1.
    """
    building_blocks_path = Path(building_blocks_path).resolve()
    suffixes = "".join(building_blocks_path.suffixes).lower()
    is_csv = suffixes.endswith((".csv", ".csv.gz"))
    suffix = building_blocks_path.suffix.lower()
    is_smiles = suffix in {".smi", ".smiles"}
    is_sdf = suffix == ".sdf"

    if not (is_csv or is_smiles or is_sdf):
        raise ValueError(
            f"Unsupported building blocks file extension: '{building_blocks_path.name}'. "
            "Supported: .smi, .smiles, .sdf, .csv, .csv.gz"
        )

    building_blocks_smiles = set()

    if standardize:
        if num_workers is None:
            # os.cpu_count() may return None when the count cannot be determined.
            num_workers = max(1, (os.cpu_count() or 1) - 1)
        if num_workers < 1:
            raise ValueError("num_workers must be >= 1")

        if is_smiles:
            # Counting records is only needed to size the progress bar.
            total = count_smiles_records(building_blocks_path) if not silent else None
            step = max(1, chunksize or 1000)
            building_blocks_smiles = _collect_standardized(
                iter_smiles_blocks(building_blocks_path, step),
                _standardize_smiles_batch,
                total=total,
                silent=silent,
                num_workers=num_workers,
            )
        elif is_csv:
            step = max(1, chunksize or 1000)
            blocks = iter_csv_smiles_blocks(
                building_blocks_path,
                step,
                header=header,
                delimiter=delimiter,
                smiles_column=smiles_column,
            )
            building_blocks_smiles = _collect_standardized(
                blocks,
                _standardize_smiles_batch,
                total=None,
                silent=silent,
                num_workers=num_workers,
            )
        elif is_sdf:
            total = count_sdf_records(building_blocks_path) if not silent else None
            step = max(1, chunksize or 5000)
            building_blocks_smiles = _collect_standardized(
                iter_sdf_text_blocks(building_blocks_path, step),
                _standardize_sdf_text,
                total=total,
                silent=silent,
                num_workers=num_workers,
            )
    else:
        if is_smiles:
            building_blocks_smiles.update(iter_smiles(building_blocks_path))
        elif is_csv:
            building_blocks_smiles.update(
                iter_csv_smiles(
                    building_blocks_path,
                    header=header,
                    delimiter=delimiter,
                    smiles_column=smiles_column,
                )
            )
        elif is_sdf:
            with SDFRead(str(building_blocks_path)) as sdf:
                for mol in sdf:
                    try:
                        building_blocks_smiles.add(str(mol))
                    except Exception:
                        # Best effort: skip records that cannot be rendered as
                        # SMILES, but leave a trace for debugging.
                        logger.debug(
                            "Skipping SDF record that could not be converted to SMILES",
                            exc_info=True,
                        )

    return frozenset(building_blocks_smiles)


def load_value_net(
    model_class: type[ValueNetwork], value_network_path: Union[str, Path]
) -> ValueNetwork:
    """Loads the value network.

    The checkpoint is mapped onto the CPU device.

    :param value_network_path: The path to the file storing value network weights.
    :param model_class: The model class to be loaded.
    :return: The loaded value network.
    """
    map_location = device("cpu")
    return model_class.load_from_checkpoint(value_network_path, map_location)


def load_policy_net(
    model_class: type[PolicyNetwork], policy_network_path: Union[str, Path]
) -> PolicyNetwork:
    """Loads the policy network.

    The checkpoint is mapped onto the CPU device and loaded with ``batch_size=1``.

    :param policy_network_path: The path to the file storing policy network weights.
    :param model_class: The model class to be loaded.
    :return: The loaded policy network.
    """
    map_location = device("cpu")
    return model_class.load_from_checkpoint(
        policy_network_path, map_location, batch_size=1
    )


def load_policy_function(
    policy_config: Union["PolicyNetworkConfig", dict, None] = None,
    weights_path: str = None,
    **config_kwargs,
) -> "PolicyNetworkFunction":
    """Factory function to create PolicyNetworkFunction with flexible configuration.

    Priority order: policy_config > weights_path + kwargs

    :param policy_config: PolicyNetworkConfig object or dict with config parameters
    :param weights_path: Direct path to weights file (shortcut for simple cases)
    :param config_kwargs: Additional config parameters forwarded to
        PolicyNetworkConfig together with ``weights_path``. Ignored when
        ``policy_config`` is given.
    :return: PolicyNetworkFunction ready for use in tree search
    :raises ValueError: If neither ``policy_config`` nor ``weights_path`` is given.

    Examples:
        >>> # Using config object
        >>> config = PolicyNetworkConfig(weights_path="path.ckpt", top_rules=50)
        >>> policy_fn = load_policy_function(policy_config=config)
        >>>
        >>> # Using direct path (simplest)
        >>> policy_fn = load_policy_function(weights_path="path.ckpt")
        >>>
        >>> # Using path with overrides
        >>> policy_fn = load_policy_function(weights_path="path.ckpt", top_rules=100)
    """
    from synplan.mcts.expansion_old import PolicyNetworkFunction
    from synplan.utils.config import PolicyNetworkConfig

    # Priority 1: Use provided config
    if policy_config is not None:
        if isinstance(policy_config, dict):
            policy_config = PolicyNetworkConfig.from_dict(policy_config)
        return PolicyNetworkFunction(policy_config=policy_config)

    # Priority 2: Create config from weights_path and kwargs
    if weights_path is not None:
        policy_config = PolicyNetworkConfig(weights_path=weights_path, **config_kwargs)
        return PolicyNetworkFunction(policy_config=policy_config)

    raise ValueError("Must provide either policy_config or weights_path")


def load_combined_policy_function(
    combined_config: Union["CombinedPolicyConfig", dict] = None,
    filtering_config: Union["PolicyNetworkConfig", dict, str] = None,
    ranking_config: Union["PolicyNetworkConfig", dict, str] = None,
    filtering_weights_path: str = None,
    ranking_weights_path: str = None,
    top_rules: int = 50,
    rule_prob_threshold: float = 0.0,
    ranking_weight: float = 1.0,
    temperature: float = 1.0,
) -> "CombinedPolicyNetworkFunction":
    """Factory function to create CombinedPolicyNetworkFunction with flexible configuration.

    Combines filtering and ranking policies by weighted addition of logits:
        combined_logits = filtering_logits + ranking_weight * ranking_logits
        combined_probs = softmax(combined_logits / temperature)

    The filtering policy provides applicability scores (trained on multi-label applicability).
    The ranking policy provides feasibility scores (trained on actual reactions).

    When ``combined_config`` is given, it takes priority and every other
    argument is ignored.

    :param combined_config: CombinedPolicyConfig or dict with all parameters.
    :param filtering_config: PolicyNetworkConfig, dict, or weights path (str)
        for filtering policy. A dict is not modified.
    :param ranking_config: PolicyNetworkConfig, dict, or weights path (str)
        for ranking policy. A dict is not modified.
    :param filtering_weights_path: Direct path to filtering weights (shortcut).
    :param ranking_weights_path: Direct path to ranking weights (shortcut).
    :param top_rules: Number of top rules to return.
    :param rule_prob_threshold: Minimum probability threshold for returning a rule.
    :param ranking_weight: Weight for ranking logits (default 1.0).
        Values > 1.0 give more weight to ranking (feasibility).
    :param temperature: Temperature for softmax (default 1.0).
        Values > 1.0 produce softer distributions (more exploration).
    :return: CombinedPolicyNetworkFunction ready for use in tree search.
    :raises ValueError: If the filtering or the ranking policy is not specified.

    Examples:
        >>> # Using CombinedPolicyConfig
        >>> config = CombinedPolicyConfig(
        ...     filtering_weights_path="filtering.ckpt",
        ...     ranking_weights_path="ranking.ckpt",
        ... )
        >>> combined = load_combined_policy_function(combined_config=config)
        >>>
        >>> # Using config objects
        >>> combined = load_combined_policy_function(
        ...     filtering_config={"weights_path": "filtering.ckpt", "policy_type": "filtering"},
        ...     ranking_config={"weights_path": "ranking.ckpt", "policy_type": "ranking"},
        ... )
        >>>
        >>> # Using direct paths (simplest)
        >>> combined = load_combined_policy_function(
        ...     filtering_weights_path="filtering.ckpt",
        ...     ranking_weights_path="ranking.ckpt",
        ... )
    """
    from synplan.mcts.expansion_old import CombinedPolicyNetworkFunction
    from synplan.utils.config import PolicyNetworkConfig, CombinedPolicyConfig

    # Priority 1: Use CombinedPolicyConfig
    if combined_config is not None:
        if isinstance(combined_config, dict):
            combined_config = CombinedPolicyConfig.from_dict(combined_config)

        filtering_weights_path = combined_config.filtering_weights_path
        ranking_weights_path = combined_config.ranking_weights_path
        top_rules = combined_config.top_rules
        rule_prob_threshold = combined_config.rule_prob_threshold
        ranking_weight = combined_config.ranking_weight
        temperature = combined_config.temperature

        filtering_config = PolicyNetworkConfig(
            weights_path=filtering_weights_path, policy_type="filtering"
        )
        ranking_config = PolicyNetworkConfig(
            weights_path=ranking_weights_path, policy_type="ranking"
        )

        return CombinedPolicyNetworkFunction(
            filtering_config=filtering_config,
            ranking_config=ranking_config,
            top_rules=top_rules,
            rule_prob_threshold=rule_prob_threshold,
            ranking_weight=ranking_weight,
            temperature=temperature,
        )

    # Build filtering config
    if filtering_config is not None:
        if isinstance(filtering_config, str):
            filtering_config = PolicyNetworkConfig(
                weights_path=filtering_config, policy_type="filtering"
            )
        elif isinstance(filtering_config, dict):
            # Merge the default into a copy so the caller's dict is not mutated.
            filtering_config = PolicyNetworkConfig.from_dict(
                {"policy_type": "filtering", **filtering_config}
            )
    elif filtering_weights_path is not None:
        filtering_config = PolicyNetworkConfig(
            weights_path=filtering_weights_path, policy_type="filtering"
        )
    else:
        raise ValueError(
            "Must provide either filtering_config or filtering_weights_path"
        )

    # Build ranking config
    if ranking_config is not None:
        if isinstance(ranking_config, str):
            ranking_config = PolicyNetworkConfig(
                weights_path=ranking_config, policy_type="ranking"
            )
        elif isinstance(ranking_config, dict):
            # Merge the default into a copy so the caller's dict is not mutated.
            ranking_config = PolicyNetworkConfig.from_dict(
                {"policy_type": "ranking", **ranking_config}
            )
    elif ranking_weights_path is not None:
        ranking_config = PolicyNetworkConfig(
            weights_path=ranking_weights_path, policy_type="ranking"
        )
    else:
        raise ValueError("Must provide either ranking_config or ranking_weights_path")

    return CombinedPolicyNetworkFunction(
        filtering_config=filtering_config,
        ranking_config=ranking_config,
        top_rules=top_rules,
        rule_prob_threshold=rule_prob_threshold,
        ranking_weight=ranking_weight,
        temperature=temperature,
    )


def load_hybrid_policy_function(
    hybrid_config: Union["HybridPolicyConfig", dict] = None,
    filtering_weights_path: str = None,
    ranking_weights_path: str = None,
    filtering_rank_weights: List[float] = None,
    ranking_rank_weights: List[float] = None,
    probability_from_score_temperature: float = None,
    hybrid_policy_type: str = None,
    **kwargs,
) -> Union["HybridPolicy", "MHybridPolicy"]:
    """Factory function to create HybridPolicy with flexible configuration.

    When ``hybrid_config`` is given it takes priority; an already constructed
    HybridPolicy / MHybridPolicy instance is returned unchanged.

    :param hybrid_config: HybridPolicyConfig or dict with all parameters.
    :param filtering_weights_path: Direct path to filtering weights (shortcut).
    :param ranking_weights_path: Direct path to ranking weights (shortcut).
    :param filtering_rank_weights: Rank weights for filtering model.
        Required for ``hybrid_policy_type="rank_weighted"``.
    :param ranking_rank_weights: Rank weights for ranking model.
        Required for ``hybrid_policy_type="rank_weighted"``.
    :param probability_from_score_temperature: Temperature for converting
        scores to probabilities.
    :param hybrid_policy_type: Hybrid aggregation mode selector.
        Supported aliases:
        - ``rank_weighted`` / ``HybridPolicy`` -> HybridPolicy
        - ``masked`` / ``MHybridPolicy`` -> MHybridPolicy
        NOTE(review): this function only tests the config's
        ``hybrid_policy_type`` against ``"masked"``; alias normalization is
        presumably done by HybridPolicyConfig -- confirm.
    :param kwargs: Extra keyword arguments forwarded to HybridPolicyConfig on
        the weights-path shortcut; they override the explicit arguments above.
    :return: Hybrid policy ready for use in tree search.
    :raises ValueError: If neither ``hybrid_config`` nor both weights paths are given.
    """
    from synplan.mcts.hybrid_policy import HybridPolicy, MHybridPolicy
    from synplan.utils.config import HybridPolicyConfig

    def _build_policy(
        config: HybridPolicyConfig,
    ) -> Union["HybridPolicy", "MHybridPolicy"]:
        if config.hybrid_policy_type == "masked":
            return MHybridPolicy(config=config)
        return HybridPolicy(config=config)

    if hybrid_config is not None:
        if isinstance(hybrid_config, (HybridPolicy, MHybridPolicy)):
            return hybrid_config
        if isinstance(hybrid_config, dict):
            hybrid_config = HybridPolicyConfig.from_dict(hybrid_config)
        return _build_policy(hybrid_config)

    if filtering_weights_path is not None and ranking_weights_path is not None:
        config_kwargs = {
            "filtering_weights_path": filtering_weights_path,
            "ranking_weights_path": ranking_weights_path,
        }
        # Only forward options the caller actually set, so the config class
        # keeps its own defaults for everything else.
        if filtering_rank_weights is not None:
            config_kwargs["filtering_rank_weights"] = filtering_rank_weights
        if ranking_rank_weights is not None:
            config_kwargs["ranking_rank_weights"] = ranking_rank_weights
        if probability_from_score_temperature is not None:
            config_kwargs["probability_from_score_temperature"] = (
                probability_from_score_temperature
            )
        if hybrid_policy_type is not None:
            config_kwargs["hybrid_policy_type"] = hybrid_policy_type
        config_kwargs.update(kwargs)
        hybrid_config = HybridPolicyConfig(**config_kwargs)
        return _build_policy(hybrid_config)

    raise ValueError(
        "Must provide either hybrid_config or both filtering_weights_path and ranking_weights_path"
    )


def load_value_network(
    value_config: Union["ValueNetworkConfig", dict, None] = None,
    weights_path: str = None,
    **config_kwargs,
) -> "ValueNetworkFunction":
    """Factory function to create ValueNetworkFunction with flexible configuration.

    Priority order: value_config > weights_path

    :param value_config: ValueNetworkConfig object or dict with config parameters
    :param weights_path: Direct path to weights file (shortcut for simple cases)
    :param config_kwargs: Accepted for interface symmetry with the other
        factories but currently unused: only the weights path is passed on.
    :return: ValueNetworkFunction ready for use in tree search
    :raises ValueError: If neither ``value_config`` nor ``weights_path`` is given.

    Examples:
        >>> # Using config object
        >>> config = ValueNetworkConfig(weights_path="path.ckpt")
        >>> value_fn = load_value_network(value_config=config)
        >>>
        >>> # Using direct path (simplest)
        >>> value_fn = load_value_network(weights_path="path.ckpt")
    """
    from synplan.mcts.evaluation import ValueNetworkFunction
    from synplan.utils.config import ValueNetworkConfig

    # Priority 1: Use provided config
    if value_config is not None:
        if isinstance(value_config, dict):
            value_config = ValueNetworkConfig.from_dict(value_config)
        # ValueNetworkFunction only takes weights_path
        return ValueNetworkFunction(weights_path=value_config.weights_path)

    # Priority 2: Use direct weights_path
    if weights_path is not None:
        return ValueNetworkFunction(weights_path=weights_path)

    raise ValueError("Must provide either value_config or weights_path")


def load_evaluation_function(eval_config) -> "EvaluationStrategy":
    """Create evaluation strategy from configuration.

    This is the central factory function that creates the appropriate evaluation
    strategy based on the config type. The config contains all necessary dependencies.

    :param eval_config: Evaluation configuration object (self-contained).
        Can be one of:
        - RolloutEvaluationConfig
        - ValueNetworkEvaluationConfig
        - RDKitEvaluationConfig
        - PolicyEvaluationConfig
        - RandomEvaluationConfig
    :return: Evaluation strategy ready to use in tree search.
    :raises ValueError: If ``eval_config`` is not one of the supported config types.

    Examples:
        >>> # Rollout evaluation
        >>> config = RolloutEvaluationConfig(
        ...     policy_network=policy,
        ...     reaction_rules=rules,
        ...     building_blocks=bbs,
        ...     max_depth=9
        ... )
        >>> evaluator = load_evaluation_function(config)
        >>>
        >>> # Value network evaluation
        >>> config = ValueNetworkEvaluationConfig(weights_path="path.ckpt")
        >>> evaluator = load_evaluation_function(config)
    """
    from synplan.mcts.evaluation import (
        RolloutEvaluationStrategy,
        ValueNetworkEvaluationStrategy,
        RDKitEvaluationStrategy,
        PolicyEvaluationStrategy,
        RandomEvaluationStrategy,
    )
    from synplan.utils.config import (
        RolloutEvaluationConfig,
        ValueNetworkEvaluationConfig,
        RDKitEvaluationConfig,
        PolicyEvaluationConfig,
        RandomEvaluationConfig,
    )

    # Lazy %-formatting: the message is only built when DEBUG is enabled.
    logger.debug("create_evaluator config_type=%s", type(eval_config).__name__)

    if isinstance(eval_config, RolloutEvaluationConfig):
        return RolloutEvaluationStrategy(
            policy_network=eval_config.policy_network,
            reaction_rules=eval_config.reaction_rules,
            building_blocks=eval_config.building_blocks,
            min_mol_size=eval_config.min_mol_size,
            max_depth=eval_config.max_depth,
            normalize=eval_config.normalize,
            stochastic=eval_config.stochastic,
        )

    elif isinstance(eval_config, ValueNetworkEvaluationConfig):
        # Load value network from path in config
        value_net = load_value_network(weights_path=eval_config.weights_path)
        return ValueNetworkEvaluationStrategy(
            value_network=value_net,
            normalize=eval_config.normalize,
        )

    elif isinstance(eval_config, RDKitEvaluationConfig):
        return RDKitEvaluationStrategy(
            score_function=eval_config.score_function,
            normalize=eval_config.normalize,
        )

    elif isinstance(eval_config, PolicyEvaluationConfig):
        return PolicyEvaluationStrategy(normalize=eval_config.normalize)

    elif isinstance(eval_config, RandomEvaluationConfig):
        return RandomEvaluationStrategy(normalize=eval_config.normalize)

    else:
        raise ValueError(
            f"Unknown evaluation config type: {type(eval_config)}. "
            f"Expected one of: RolloutEvaluationConfig, ValueNetworkEvaluationConfig, "
            f"RDKitEvaluationConfig, PolicyEvaluationConfig, RandomEvaluationConfig."
        )