# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
from typing import Callable, List, Optional, Tuple
import numpy as np
from language.energy import EnergyTerm
from language.folding_callbacks import FoldingResult
from language.sequence import SequenceSegmentFactory
# Gap added to the running residue-index offset after each child of a node
# whose `children_are_different_chains` flag is set (see
# ProgramNode.get_sequence_and_set_residue_index_ranges).
MULTIMER_RESIDUE_INDEX_SKIP_LENGTH: int = 1000
class ProgramNode:
    """A node in a tree that assembles a sequence from sequence segments.

    A node is a *leaf* when ``children`` is ``None``; a leaf obtains its
    sequence from ``sequence_segment``. Any other node concatenates the
    sequences of its children in order. Every node may carry energy function
    terms (objects exposing ``compute(node, folding_result)``) together with
    one weight per term.
    """

    def __init__(
        self,
        children: Optional[List["ProgramNode"]] = None,
        sequence_segment: Optional["SequenceSegmentFactory"] = None,
        children_are_different_chains: bool = False,
        energy_function_terms: Optional[List["EnergyTerm"]] = None,
        energy_function_weights: Optional[List[float]] = None,
    ) -> None:
        """Create a node.

        Args:
            children: Child nodes, or ``None`` to make this a leaf node.
            sequence_segment: Segment factory used by leaf nodes; it must
                provide ``get()``, ``mutate()`` and
                ``num_mutation_candidates()``.
            children_are_different_chains: When true, the residue index
                offset is advanced by ``MULTIMER_RESIDUE_INDEX_SKIP_LENGTH``
                after each child.
            energy_function_terms: Energy terms attached to this node.
                Omitting it (or passing ``None``) gives the node its own
                fresh empty list.
            energy_function_weights: One weight per energy term. When
                omitted or empty, every term gets weight ``1.0``.

        Raises:
            AssertionError: If weights are given and their count differs
                from the number of energy terms.
        """
        self.children: Optional[List["ProgramNode"]] = children
        self.sequence_segment: Optional["SequenceSegmentFactory"] = sequence_segment
        self.children_are_different_chains: bool = children_are_different_chains

        # A fresh list per instance: a shared ``[]`` default would leak terms
        # appended on one node into every other node built without terms.
        self.energy_function_terms: List["EnergyTerm"] = (
            energy_function_terms if energy_function_terms is not None else []
        )
        self.energy_function_weights: List[float] = (
            energy_function_weights
            if energy_function_weights
            else [1.0 for _ in self.energy_function_terms]
        )
        if self.energy_function_weights:
            assert len(self.energy_function_terms) == len(
                self.energy_function_weights
            ), "One must have the same number of energy function terms and weights on a node."

        # Half-open (start, end) residue index range; populated by
        # get_sequence_and_set_residue_index_ranges().
        self.residue_index_range: Optional[Tuple[int, int]] = None

    def get_sequence_and_set_residue_index_ranges(
        self, residue_index_offset: int = 1
    ) -> Tuple[str, List[int]]:
        """Build this subtree's sequence and record residue index ranges.

        Sets ``residue_index_range`` on this node and, recursively, on every
        descendant.

        Args:
            residue_index_offset: Index assigned to the first residue of
                this subtree.

        Returns:
            A tuple ``(sequence, residue_indices)`` with one index per
            residue of the returned sequence.
        """
        if self.is_leaf_node():
            sequence = self.sequence_segment.get()
            self.residue_index_range = (
                residue_index_offset,
                residue_index_offset + len(sequence),
            )
            return sequence, list(range(*self.residue_index_range))

        offset: int = residue_index_offset
        segments: List[str] = []
        residue_indices: List[int] = []
        for child in self.children:
            (
                sequence_segment,
                residue_indices_segment,
            ) = child.get_sequence_and_set_residue_index_ranges(
                residue_index_offset=offset
            )
            segments.append(sequence_segment)
            residue_indices += residue_indices_segment
            # Nothing has been emitted yet if all children so far were empty;
            # keep the current offset instead of indexing an empty list.
            if residue_indices:
                offset = residue_indices[-1] + 1
            if self.children_are_different_chains:
                offset += MULTIMER_RESIDUE_INDEX_SKIP_LENGTH

        if residue_indices:
            self.residue_index_range = (residue_indices[0], residue_indices[-1] + 1)
        else:
            # No residues in this subtree: record an empty range at the offset.
            self.residue_index_range = (residue_index_offset, residue_index_offset)
        return "".join(segments), residue_indices

    def get_residue_index_range(self) -> Tuple[int, int]:
        """Return the half-open ``(start, end)`` residue index range.

        Raises:
            AssertionError: If the range has not been computed yet.
        """
        assert (
            self.residue_index_range is not None
        ), "Must call get_sequence_and_set_residue_index_ranges() first."
        return self.residue_index_range

    def get_children(self) -> Optional[List["ProgramNode"]]:
        """Return the child nodes, or ``None`` for a leaf node."""
        return self.children

    def is_leaf_node(self) -> bool:
        """Return whether this node has no ``children`` list at all."""
        return self.children is None

    def get_energy_term_functions(
        self, name_prefix: str = ""
    ) -> List[Tuple[str, float, Callable[["FoldingResult"], float]]]:
        """Collect the energy terms of this node and all of its descendants.

        Args:
            name_prefix: Name of this node; an empty prefix becomes
                ``"root"``. The i-th child (1-based) is named
                ``"<prefix>.n<i>"``.

        Returns:
            A list of ``(name, weight, function)`` tuples, this node's terms
            first, then each child's in order. ``name`` is
            ``"<prefix>:<term class name>"`` and ``function`` is the term's
            ``compute`` with this node bound as its first argument.
        """
        name_prefix = name_prefix if name_prefix else "root"
        terms = [
            (
                f"{name_prefix}:{type(term).__name__}",
                weight,
                partial(term.compute, self),
            )
            for weight, term in zip(
                self.energy_function_weights, self.energy_function_terms
            )
        ]
        if self.is_leaf_node():
            return terms

        for position, child in enumerate(self.children, start=1):
            terms += child.get_energy_term_functions(
                name_prefix=f"{name_prefix}.n{position}"
            )
        return terms

    def mutate(self) -> None:
        """Mutate one sequence segment somewhere in this subtree.

        A leaf delegates to its segment's ``mutate()``. Any other node picks
        one child at random, with probability proportional to that child's
        ``num_mutation_candidates()``, and mutates it.

        Raises:
            AssertionError: If no child reports any mutation candidates.
        """
        if self.is_leaf_node():
            return self.sequence_segment.mutate()

        weights = np.array(
            [float(child.num_mutation_candidates()) for child in self.children]
        )
        assert (
            weights.sum() > 0
        ), "Some mutations should be possible if mutate() was called."
        # Sample an index rather than the node objects themselves so NumPy
        # never has to coerce the children into an array.
        chosen = int(np.random.choice(len(self.children), p=weights / weights.sum()))
        self.children[chosen].mutate()

    def num_mutation_candidates(self) -> int:
        """Return the total number of mutation candidates in this subtree."""
        if self.is_leaf_node():
            return self.sequence_segment.num_mutation_candidates()
        return sum(child.num_mutation_candidates() for child in self.children)
|