# kabudadada
# Add esm folder and minimal app
# e76b79a
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
from typing import Callable, List, Optional, Tuple
import numpy as np
from language.energy import EnergyTerm
from language.folding_callbacks import FoldingResult
from language.sequence import SequenceSegmentFactory
# Gap added to the running residue index offset after a child of a node whose
# `children_are_different_chains` flag is set, so that consecutive children of
# such a node do not receive contiguous residue indices.
MULTIMER_RESIDUE_INDEX_SKIP_LENGTH: int = 1000
class ProgramNode:
    """Node of a tree that assembles a sequence and collects energy terms.

    A node is a leaf when ``children`` is ``None``; a leaf draws its sequence
    from ``sequence_segment``. A non-leaf node concatenates the sequences of
    its children in order. Every node may carry its own weighted energy terms,
    which are gathered recursively by :meth:`get_energy_term_functions`.
    """

    def __init__(
        self,
        children: Optional[List["ProgramNode"]] = None,
        sequence_segment: Optional[SequenceSegmentFactory] = None,
        children_are_different_chains: bool = False,
        energy_function_terms: Optional[List[EnergyTerm]] = None,
        energy_function_weights: Optional[List[float]] = None,
    ) -> None:
        """Create a node.

        Args:
            children: Child nodes, or ``None`` to make this node a leaf.
            sequence_segment: Source of the sequence for a leaf node; it is
                only used when the node is a leaf.
            children_are_different_chains: When true, a gap of
                ``MULTIMER_RESIDUE_INDEX_SKIP_LENGTH`` is inserted into the
                residue numbering between children.
            energy_function_terms: Energy terms attached to this node.
                ``None`` (the default) means no terms.
            energy_function_weights: One weight per energy term. When omitted
                or empty, every term gets weight ``1.0``.

        Raises:
            AssertionError: If the number of weights differs from the number
                of energy terms.
        """
        self.children: Optional[List["ProgramNode"]] = children
        self.sequence_segment: Optional[SequenceSegmentFactory] = sequence_segment
        self.children_are_different_chains: bool = children_are_different_chains

        # A fresh list per instance: a shared ``[]`` default would leak terms
        # between every node constructed without an explicit list.
        self.energy_function_terms: List[EnergyTerm] = (
            energy_function_terms if energy_function_terms is not None else []
        )
        self.energy_function_weights: List[float] = (
            energy_function_weights
            if energy_function_weights
            else [1.0 for _ in self.energy_function_terms]
        )
        assert len(self.energy_function_terms) == len(
            self.energy_function_weights
        ), "One must have the same number of energy function terms and weights on a node."

        # Half-open (start, end) range; populated by
        # get_sequence_and_set_residue_index_ranges().
        self.residue_index_range: Optional[Tuple[int, int]] = None

    def get_sequence_and_set_residue_index_ranges(
        self, residue_index_offset: int = 1
    ) -> Tuple[str, List[int]]:
        """Build this subtree's sequence and record residue index ranges.

        As a side effect, ``residue_index_range`` is set on this node and on
        every descendant as a half-open ``(start, end)`` pair.

        Args:
            residue_index_offset: Residue index given to the first residue of
                this subtree.

        Returns:
            The concatenated sequence and the residue index of each position.
        """
        if self.is_leaf_node():
            sequence = self.sequence_segment.get()
            self.residue_index_range = (
                residue_index_offset,
                residue_index_offset + len(sequence),
            )
            return sequence, list(range(*self.residue_index_range))

        offset = residue_index_offset
        segments: List[str] = []
        residue_indices: List[int] = []
        for child in self.children:
            (
                child_sequence,
                child_indices,
            ) = child.get_sequence_and_set_residue_index_ranges(
                residue_index_offset=offset
            )
            segments.append(child_sequence)
            residue_indices.extend(child_indices)
            # A child that produced no residues leaves the offset untouched;
            # indexing [-1] on an empty list would otherwise raise IndexError.
            if child_indices:
                offset = child_indices[-1] + 1
                if self.children_are_different_chains:
                    offset += MULTIMER_RESIDUE_INDEX_SKIP_LENGTH

        if residue_indices:
            self.residue_index_range = (residue_indices[0], residue_indices[-1] + 1)
        else:
            # No residues anywhere below this node: record an empty range,
            # matching what a leaf records for an empty sequence.
            self.residue_index_range = (residue_index_offset, residue_index_offset)
        return "".join(segments), residue_indices

    def get_residue_index_range(self) -> Tuple[int, int]:
        """Return the half-open ``(start, end)`` residue index range.

        Raises:
            AssertionError: If the range has not been computed yet.
        """
        assert (
            self.residue_index_range is not None
        ), "Must call get_sequence_and_set_residue_index_ranges() first."
        return self.residue_index_range

    def get_children(self) -> Optional[List["ProgramNode"]]:
        """Return the child nodes, or ``None`` for a leaf node."""
        return self.children

    def is_leaf_node(self) -> bool:
        """Return whether this node has no ``children`` list at all."""
        return self.children is None

    def get_energy_term_functions(
        self, name_prefix: str = ""
    ) -> List[Tuple[str, float, Callable[[FoldingResult], float]]]:
        """Collect the energy terms of this node and all of its descendants.

        Args:
            name_prefix: Name of this node in the generated labels; an empty
                value is replaced by ``"root"``. The i-th child (1-based) is
                named ``<name_prefix>.n<i>``.

        Returns:
            One ``(label, weight, function)`` tuple per energy term, where
            ``label`` is ``"<node name>:<term class name>"`` and ``function``
            is the term's ``compute`` with this node bound as first argument.
        """
        name_prefix = name_prefix or "root"
        terms = [
            (
                f"{name_prefix}:{type(term).__name__}",
                weight,
                partial(term.compute, self),
            )
            for weight, term in zip(
                self.energy_function_weights, self.energy_function_terms
            )
        ]
        if self.is_leaf_node():
            return terms

        for i, child in enumerate(self.children):
            terms += child.get_energy_term_functions(
                name_prefix=name_prefix + f".n{i+1}"
            )
        return terms

    def mutate(self) -> None:
        """Mutate one leaf segment somewhere in this subtree.

        A leaf delegates to its ``sequence_segment``. A non-leaf node picks
        one child at random, with probability proportional to that child's
        number of mutation candidates, and recurses into it.

        Raises:
            AssertionError: If no child has any mutation candidate.
        """
        if self.is_leaf_node():
            return self.sequence_segment.mutate()

        weights = np.array(
            [float(child.num_mutation_candidates()) for child in self.children]
        )
        assert (
            weights.sum() > 0
        ), "Some mutations should be possible if mutate() was called."
        child_to_mutate = np.random.choice(self.children, p=weights / weights.sum())
        child_to_mutate.mutate()

    def num_mutation_candidates(self) -> int:
        """Return the number of mutation candidates in this subtree."""
        if self.is_leaf_node():
            return self.sequence_segment.num_mutation_candidates()
        return sum(child.num_mutation_candidates() for child in self.children)