Spaces:
Running on Zero
Running on Zero
| # Copyright 2021 AlQuraishi Laboratory | |
| # Copyright 2021 DeepMind Technologies Limited | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import copy | |
| from typing import Mapping, Tuple, List, Optional, Dict, Sequence | |
| import ml_collections | |
| import numpy as np | |
| import torch | |
| from openfold.data import input_pipeline | |
| FeatureDict = Mapping[str, np.ndarray] | |
| TensorDict = Dict[str, torch.Tensor] | |
| def np_to_tensor_dict( | |
| np_example: Mapping[str, np.ndarray], | |
| features: Sequence[str], | |
| ) -> TensorDict: | |
| """Creates dict of tensors from a dict of NumPy arrays. | |
| Args: | |
| np_example: A dict of NumPy feature arrays. | |
| features: A list of strings of feature names to be returned in the dataset. | |
| Returns: | |
| A dictionary of features mapping feature names to features. Only the given | |
| features are returned, all other ones are filtered out. | |
| """ | |
| tensor_dict = { | |
| k: torch.tensor(v) for k, v in np_example.items() if k in features | |
| } | |
| return tensor_dict | |
| def make_data_config( | |
| config: ml_collections.ConfigDict, | |
| mode: str, | |
| num_res: int, | |
| ) -> Tuple[ml_collections.ConfigDict, List[str]]: | |
| cfg = copy.deepcopy(config) | |
| mode_cfg = cfg[mode] | |
| with cfg.unlocked(): | |
| if mode_cfg.crop_size is None: | |
| mode_cfg.crop_size = num_res | |
| feature_names = cfg.common.unsupervised_features | |
| if cfg.common.use_templates: | |
| feature_names += cfg.common.template_features | |
| if cfg[mode].supervised: | |
| feature_names += cfg.supervised.supervised_features | |
| return cfg, feature_names | |
| def np_example_to_features( | |
| np_example: FeatureDict, | |
| config: ml_collections.ConfigDict, | |
| mode: str, | |
| ): | |
| np_example = dict(np_example) | |
| num_res = int(np_example["seq_length"][0]) | |
| cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res) | |
| if "deletion_matrix_int" in np_example: | |
| np_example["deletion_matrix"] = np_example.pop( | |
| "deletion_matrix_int" | |
| ).astype(np.float32) | |
| tensor_dict = np_to_tensor_dict( | |
| np_example=np_example, features=feature_names | |
| ) | |
| with torch.no_grad(): | |
| features = input_pipeline.process_tensors_from_config( | |
| tensor_dict, | |
| cfg.common, | |
| cfg[mode], | |
| ) | |
| return {k: v for k, v in features.items()} | |
| class FeaturePipeline: | |
| def __init__( | |
| self, | |
| config: ml_collections.ConfigDict, | |
| ): | |
| self.config = config | |
| def process_features( | |
| self, | |
| raw_features: FeatureDict, | |
| mode: str = "train", | |
| ) -> FeatureDict: | |
| return np_example_to_features( | |
| np_example=raw_features, | |
| config=self.config, | |
| mode=mode, | |
| ) | |