Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
| from functools import partial | |
| from typing import Callable, List, Optional, Tuple | |
| import numpy as np | |
| from language.energy import EnergyTerm | |
| from language.folding_callbacks import FoldingResult | |
| from language.sequence import SequenceSegmentFactory | |
# Gap added to the running residue index between consecutive children of a
# node whose children are flagged as different chains.
MULTIMER_RESIDUE_INDEX_SKIP_LENGTH: int = 1000


class ProgramNode:
    """One node of a program tree.

    A node is either a leaf (``children is None``), which owns a sequence
    segment factory, or an internal node, which concatenates the sequences
    of its children. Every node may carry weighted energy terms.
    """

    def __init__(
        self,
        children: Optional[List["ProgramNode"]] = None,
        sequence_segment: SequenceSegmentFactory = None,
        children_are_different_chains: bool = False,
        energy_function_terms: Optional[List[EnergyTerm]] = None,
        energy_function_weights: Optional[List[float]] = None,
    ) -> None:
        """Create a node.

        Args:
            children: Child nodes, or ``None`` to make this a leaf.
            sequence_segment: Factory used by leaf nodes to produce and
                mutate their sequence.
            children_are_different_chains: When true, a gap of
                ``MULTIMER_RESIDUE_INDEX_SKIP_LENGTH`` is inserted into the
                residue numbering between consecutive children.
            energy_function_terms: Energy terms attached to this node.
                Defaults to no terms.
            energy_function_weights: One weight per energy term. When
                omitted (or empty) every term gets weight ``1.0``.

        Raises:
            AssertionError: If the number of weights differs from the number
                of energy terms.
        """
        self.children: Optional[List["ProgramNode"]] = children
        self.sequence_segment: SequenceSegmentFactory = sequence_segment
        self.children_are_different_chains: bool = children_are_different_chains

        # Build a fresh list per node; a literal ``[]`` default in the
        # signature would be one object shared by every node created without
        # explicit terms.
        self.energy_function_terms: List[EnergyTerm] = (
            energy_function_terms if energy_function_terms is not None else []
        )
        self.energy_function_weights: List[float] = (
            energy_function_weights
            if energy_function_weights
            else [1.0] * len(self.energy_function_terms)
        )
        assert len(self.energy_function_terms) == len(
            self.energy_function_weights
        ), "One must have the same number of energy function terms and weights on a node."

        # Half-open (start, end) range; populated by
        # get_sequence_and_set_residue_index_ranges().
        self.residue_index_range: Optional[Tuple[int, int]] = None

    def get_sequence_and_set_residue_index_ranges(
        self, residue_index_offset: int = 1
    ) -> Tuple[str, List[int]]:
        """Build this subtree's sequence and record residue index ranges.

        Sets ``residue_index_range`` on this node and on every descendant as
        a side effect.

        Args:
            residue_index_offset: Index assigned to the first residue of this
                subtree.

        Returns:
            The concatenated sequence and the residue index of each of its
            positions.
        """
        if self.is_leaf_node():
            sequence = self.sequence_segment.get()
            self.residue_index_range = (
                residue_index_offset,
                residue_index_offset + len(sequence),
            )
            return sequence, list(range(*self.residue_index_range))

        offset = residue_index_offset
        segments: List[str] = []
        residue_indices: List[int] = []
        for child in self.children:
            segment, segment_indices = child.get_sequence_and_set_residue_index_ranges(
                residue_index_offset=offset
            )
            segments.append(segment)
            residue_indices += segment_indices
            if not residue_indices:
                # Nothing emitted yet (leading empty segments): keep the
                # offset unchanged instead of indexing an empty list.
                continue
            offset = residue_indices[-1] + 1
            if self.children_are_different_chains:
                offset += MULTIMER_RESIDUE_INDEX_SKIP_LENGTH

        if residue_indices:
            self.residue_index_range = (residue_indices[0], residue_indices[-1] + 1)
        else:
            # Empty subtree: same empty range a leaf with an empty sequence
            # records for itself.
            self.residue_index_range = (residue_index_offset, residue_index_offset)
        return "".join(segments), residue_indices

    def get_residue_index_range(self) -> Tuple[int, int]:
        """Return the half-open residue index range recorded for this node.

        Raises:
            AssertionError: If get_sequence_and_set_residue_index_ranges()
                has not been called on a tree containing this node.
        """
        assert (
            self.residue_index_range is not None
        ), "Must call get_sequence_and_set_residue_index_ranges() first."
        return self.residue_index_range

    def get_children(self) -> Optional[List["ProgramNode"]]:
        """Return the child nodes, or ``None`` for a leaf."""
        return self.children

    def is_leaf_node(self) -> bool:
        """Return whether this node has no ``children`` list at all."""
        return self.children is None

    def get_energy_term_functions(
        self, name_prefix: str = ""
    ) -> List[Tuple[str, float, Callable[[FoldingResult], float]]]:
        """Collect the energy terms of this node and all of its descendants.

        Args:
            name_prefix: Label for this node; an empty value means ``"root"``.
                The i-th child (1-based) is labelled ``<prefix>.n<i>``.

        Returns:
            ``(name, weight, function)`` triples, this node's terms first and
            then each child's in order. ``name`` is
            ``<prefix>:<term class name>`` and ``function`` is the term's
            ``compute`` with this node bound as its first argument.
        """
        name_prefix = name_prefix or "root"
        terms = [
            (
                f"{name_prefix}:{type(term).__name__}",
                weight,
                partial(term.compute, self),
            )
            for weight, term in zip(
                self.energy_function_weights, self.energy_function_terms
            )
        ]
        if self.is_leaf_node():
            return terms

        for i, child in enumerate(self.children, start=1):
            terms += child.get_energy_term_functions(name_prefix=f"{name_prefix}.n{i}")
        return terms

    def mutate(self) -> None:
        """Apply one mutation somewhere in this subtree.

        A leaf delegates to its sequence segment. An internal node picks one
        child at random, weighted by each child's number of mutation
        candidates, and recurses into it.

        Raises:
            AssertionError: If no child has any mutation candidate.
        """
        if self.is_leaf_node():
            self.sequence_segment.mutate()
            return

        weights = np.array(
            [float(child.num_mutation_candidates()) for child in self.children]
        )
        total = weights.sum()
        assert (
            total > 0
        ), "Some mutations should be possible if mutate() was called."
        # Draw an index rather than a node so numpy never has to coerce node
        # objects into an array; the draw itself is unchanged.
        chosen = np.random.choice(len(self.children), p=weights / total)
        self.children[chosen].mutate()

    def num_mutation_candidates(self) -> int:
        """Return the total number of mutation candidates in this subtree."""
        if self.is_leaf_node():
            return self.sequence_segment.num_mutation_candidates()
        return sum(child.num_mutation_candidates() for child in self.children)