File size: 5,001 Bytes
72a3513
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""Module containing functions for loading reaction rules, building blocks and
retrosynthetic models."""

import functools
import pickle
import zipfile
from pathlib import Path
from typing import List, Set, Union

from CGRtools.reactor.reactor import Reactor
from torch import device
from huggingface_hub import hf_hub_download, snapshot_download
from tqdm import tqdm

from synplan.ml.networks.policy import PolicyNetwork
from synplan.ml.networks.value import ValueNetwork
from synplan.utils.files import MoleculeReader


def download_unpack_data(filename, subfolder, save_to="."):
    """Downloads a file from the SynPlanner Hugging Face repository and unpacks it
    if it is a zip archive.

    :param filename: Name of the file to download from the repository.
    :param subfolder: Repository subfolder that contains ``filename``.
    :param save_to: Directory (``str`` or ``Path``) to download into. It is
        resolved to an absolute path and created (with parents) if missing.
        Default is the current directory.
    :return: For a ``.zip`` download, the path of the first member listed in the
        archive after all members have been extracted into ``save_to`` (the
        archive itself is deleted). Otherwise, the path of the downloaded file.
    :raises ValueError: If the downloaded archive contains no members.
    """
    # Normalise str and Path inputs alike. Previously only str inputs were
    # resolved and created, and nested directories could not be created.
    save_to = Path(save_to).resolve()
    save_to.mkdir(parents=True, exist_ok=True)

    # Download the file from the repository
    file_path = Path(
        hf_hub_download(
            repo_id="Laboratoire-De-Chemoinformatique/SynPlanner",
            filename=filename,
            subfolder=subfolder,
            local_dir=save_to,
        )
    )

    if file_path.suffix != ".zip":
        return file_path

    with zipfile.ZipFile(file_path, "r") as zip_ref:
        members = zip_ref.namelist()
        if not members:
            # Without this check, namelist()[0] raises an opaque IndexError
            raise ValueError(f"Downloaded archive {file_path} contains no files")
        # Extract every member; the first listed one is reported to the caller
        zip_ref.extractall(save_to)
        extracted_file = save_to / members[0]

    # The archive is no longer needed once its contents are on disk
    file_path.unlink()

    return extracted_file


def download_all_data(save_to="."):
    """Downloads the full SynPlanner repository snapshot and unpacks its archives.

    Every ``.zip`` file found (recursively) in the downloaded directory is
    unpacked next to itself. Archive members whose target path already exists
    are skipped, so re-running does not overwrite earlier extractions. The
    archives themselves are kept.

    :param save_to: Local directory to download the snapshot into.
        Default is the current directory.
    """
    snapshot_root = Path(
        snapshot_download(
            repo_id="Laboratoire-De-Chemoinformatique/SynPlanner", local_dir=save_to
        )
    ).resolve()

    for zip_file in snapshot_root.rglob("*.zip"):
        with zipfile.ZipFile(zip_file, "r") as archive:
            for file_name in archive.namelist():
                # Skip members that an earlier run already unpacked
                if (zip_file.parent / file_name).exists():
                    continue
                archive.extract(file_name, zip_file.parent)
                print(f"Extracted {file_name} to {zip_file.parent}")


@functools.lru_cache(maxsize=None)
def load_reaction_rules(file: str) -> List[Reactor]:
    """Loads the reaction rules from a pickle file and converts them into a list of
    Reactor objects if necessary.

    The pickle is expected to hold a sequence of pairs whose first item is a
    rule. If the first item of the first pair is not already a ``Reactor``,
    every first item is wrapped in ``Reactor``. If it is already a ``Reactor``,
    the unpickled object is returned unchanged. An empty sequence yields ``[]``.

    Results are cached per ``file`` argument, so repeated calls return the same
    list object; callers must not mutate it.

    :param file: The path to the pickle file that stores the reaction rules.
    :return: A list of reaction rules as Reactor objects.
    """

    # NOTE(review): pickle.load can execute arbitrary code; only load rule
    # files from trusted sources.
    with open(file, "rb") as f:
        reaction_rules = pickle.load(f)

    # Guard: an empty rule set has no first element to inspect, and indexing
    # it would raise IndexError
    if not reaction_rules:
        return []

    if not isinstance(reaction_rules[0][0], Reactor):
        # Second item of each pair is discarded; only the rule itself is kept
        reaction_rules = [Reactor(x) for x, _ in reaction_rules]

    return reaction_rules


@functools.lru_cache(maxsize=None)
def load_building_blocks(
    building_blocks_path: Union[str, Path], standardize: bool = True
) -> Set[str]:
    """Loads building blocks from a SMILES file and returns them as a set of
    SMILES strings.

    Results are cached per ``(building_blocks_path, standardize)`` arguments,
    so repeated calls return the same (mutable) set object; callers must not
    modify it.

    :param building_blocks_path: The path to the file containing the building
        blocks. It must have a ``.smi`` or ``.smiles`` extension.
    :param standardize: If True, each molecule is read with ``MoleculeReader``,
        canonicalized and stripped of stereo before its SMILES is stored.
        Molecules that raise during this step are skipped. If False, the first
        whitespace-separated token of every non-blank line is stored verbatim.
        Default=True.
    :return: The set of building blocks smiles.
    :raises AssertionError: If the file extension is not ``.smi``/``.smiles``.
    """

    building_blocks_path = Path(building_blocks_path).resolve()
    assert building_blocks_path.suffix in (".smi", ".smiles"), (
        "Building blocks file must have a .smi or .smiles extension, "
        f"got: {building_blocks_path}"
    )

    building_blocks_smiles = set()
    if standardize:
        with MoleculeReader(building_blocks_path) as molecules:
            for mol in tqdm(
                molecules,
                desc="Number of building blocks processed: ",
                bar_format="{desc}{n} [{elapsed}]",
            ):
                try:
                    mol.canonicalize()
                    mol.clean_stereo()
                    building_blocks_smiles.add(str(mol))
                except Exception:  # e.g. InvalidAromaticRing from canonicalize()
                    # Best-effort: skip molecules that cannot be standardized,
                    # but let KeyboardInterrupt/SystemExit propagate (unlike a
                    # bare except).
                    continue
    else:
        with open(building_blocks_path, "r") as inp:
            for line in inp:
                fields = line.split()
                # Blank or whitespace-only lines (e.g. a trailing newline)
                # have no SMILES token; previously they raised IndexError.
                if fields:
                    building_blocks_smiles.add(fields[0])

    return building_blocks_smiles


def load_value_net(
    model_class: ValueNetwork, value_network_path: Union[str, Path]
) -> ValueNetwork:
    """Restores a value network from a checkpoint, mapping its weights to the CPU.

    :param model_class: The model class to restore; its ``load_from_checkpoint``
        is called.
    :param value_network_path: Location of the checkpoint holding the value
        network weights.
    :return: The restored value network.
    """

    return model_class.load_from_checkpoint(
        value_network_path, map_location=device("cpu")
    )


def load_policy_net(
    model_class: PolicyNetwork, policy_network_path: Union[str, Path]
) -> PolicyNetwork:
    """Restores a policy network from a checkpoint, mapping its weights to the CPU.

    :param model_class: The model class to restore; its ``load_from_checkpoint``
        is called.
    :param policy_network_path: Location of the checkpoint holding the policy
        network weights.
    :return: The restored policy network.
    """

    # batch_size=1 is forwarded to load_from_checkpoint; presumably it overrides
    # the batch size stored in the checkpoint's hyperparameters (TODO confirm).
    return model_class.load_from_checkpoint(
        policy_network_path,
        map_location=device("cpu"),
        batch_size=1,
    )