kabudadada
Add esm folder and minimal app
e76b79a
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional

import numpy as np
from rich.live import Live
from rich.table import Table

from language.folding_callbacks import FoldingCallback
from language.program import ProgramNode
@dataclass
class MetropolisHastingsState:
    """Snapshot of a Metropolis-Hastings chain between two steps.

    The six energy fields are ``None`` until a first step has been taken:
    ``run_simulated_annealing`` builds its initial state that way and
    ``metropolis_hastings_step`` checks ``current_energy`` and ``best_energy``
    against ``None``. They are therefore declared optional and default to
    ``None``, so an initial state only needs the first four fields.

    Each ``*_energy_term_fn_values`` list holds ``(name, weight, value)``
    triples, one per energy term; the matching ``*_energy`` is the sum of
    ``weight * value`` over that list.
    """

    # Program currently held by the chain (the last accepted candidate, or
    # the starting program if nothing has been accepted yet).
    program: ProgramNode
    # Temperature used by the most recent step (the initial temperature
    # before any step has run).
    temperature: float
    # Factor the temperature is multiplied by at the start of every step.
    annealing_rate: float
    # Number of steps taken so far.
    num_steps: int
    # Total energy and per-term values of the most recently proposed candidate.
    candidate_energy: Optional[float] = None
    candidate_energy_term_fn_values: Optional[list] = None
    # Total energy and per-term values of the program currently held.
    current_energy: Optional[float] = None
    current_energy_term_fn_values: Optional[list] = None
    # Lowest total energy proposed so far (accepted or not) and its terms.
    best_energy: Optional[float] = None
    best_energy_term_fn_values: Optional[list] = None
def metropolis_hastings_step(
    state: MetropolisHastingsState,
    folding_callback: FoldingCallback,
    verbose: bool = False,
) -> MetropolisHastingsState:
    """Propose one mutation, score it, and accept or reject it.

    The temperature is first multiplied by ``state.annealing_rate``. A deep
    copy of ``state.program`` is mutated, its sequence is folded through
    ``folding_callback``, and every energy term of the candidate is evaluated
    on the folding output. The candidate energy is the weighted sum of those
    terms (lower is better).

    Acceptance rule:
      * no current energy yet (first step): always accept;
      * candidate is at least as good as the current program: accept;
      * candidate is worse: accept with probability
        ``exp((current - candidate) / temperature)``; with a non-positive
        temperature a worse candidate is never accepted.

    Args:
        state: Chain state before the step. It is not modified; the program
            is deep-copied before being mutated.
        folding_callback: Object whose ``fold(sequence, residue_indices)``
            result is passed to every energy term function.
        verbose: If true, print a line whenever a candidate is accepted.

    Returns:
        A new ``MetropolisHastingsState`` with the annealed temperature,
        ``num_steps`` incremented, the candidate's energies, and the
        current/best fields updated. ``best_*`` tracks the lowest candidate
        energy proposed so far, whether or not that candidate was accepted.
    """
    temperature = state.temperature * state.annealing_rate

    candidate: ProgramNode = deepcopy(state.program)
    candidate.mutate()

    sequence, residue_indices = candidate.get_sequence_and_set_residue_index_ranges()
    folding_output = folding_callback.fold(sequence, residue_indices)

    energy_term_fns = candidate.get_energy_term_functions()
    candidate_energy_term_fn_values = [
        (name, weight, energy_fn(folding_output)) for name, weight, energy_fn in energy_term_fns
    ]
    # TODO(scandido): Log these.
    candidate_energy: float = sum(
        weight * value for _, weight, value in candidate_energy_term_fn_values
    )

    if state.current_energy is None:
        accept_candidate = True
    else:
        # NOTE(scandido): We are minimizing the function here so instead of
        # candidate - current we do -1 * (candidate - current) = -candidate + current.
        energy_differential: float = -candidate_energy + state.current_energy

        # NOTE(scandido): We approximate the ratio of transition probabilities from
        # current to candidate vs. candidate to current to be equal, which is
        # approximately correct.
        if energy_differential >= 0:
            # exp(x) >= 1 for x >= 0, so the probability is 1 after capping.
            # Short-circuiting avoids overflow in exp() for large improvements
            # and a 0 / 0 division once the temperature has underflowed to zero.
            accept_probability = 1.0
        elif temperature > 0:
            # The exponent is negative here, so exp() stays within [0, 1).
            accept_probability = float(np.exp(energy_differential / temperature))
        else:
            # Zero (or invalid negative) temperature: purely greedy search.
            accept_probability = 0.0

        # Always draw exactly one uniform sample so that seeded runs consume
        # the random stream the same way regardless of the branch taken.
        accept_candidate = bool(np.random.uniform() < accept_probability)

    if accept_candidate and verbose:
        print(f"Accepted {sequence} with energy {candidate_energy:.2f}.")

    best = (state.best_energy is None) or candidate_energy < state.best_energy

    return MetropolisHastingsState(
        program=candidate if accept_candidate else state.program,
        temperature=temperature,
        annealing_rate=state.annealing_rate,
        num_steps=state.num_steps + 1,
        candidate_energy=candidate_energy,
        candidate_energy_term_fn_values=candidate_energy_term_fn_values,
        current_energy=candidate_energy if accept_candidate else state.current_energy,
        current_energy_term_fn_values=(
            candidate_energy_term_fn_values
            if accept_candidate
            else state.current_energy_term_fn_values
        ),
        best_energy=candidate_energy if best else state.best_energy,
        best_energy_term_fn_values=(
            candidate_energy_term_fn_values if best else state.best_energy_term_fn_values
        ),
    )
def run_simulated_annealing(
    program: ProgramNode,
    initial_temperature: float,
    annealing_rate: float,
    total_num_steps: int,
    folding_callback: FoldingCallback,
    display_progress: bool = True,
    progress_verbose_print: bool = False,
) -> ProgramNode:
    """Run ``total_num_steps`` Metropolis-Hastings steps starting from ``program``.

    Args:
        program: Starting program; each step works on a deep copy of the
            program currently held by the chain.
        initial_temperature: Temperature stored in the initial state; it is
            multiplied by ``annealing_rate`` at the start of every step.
        annealing_rate: Per-step multiplicative factor for the temperature.
        total_num_steps: Number of steps to run.
        folding_callback: Passed through to ``metropolis_hastings_step``.
        display_progress: If true, show a live ``rich`` table of the
            candidate / current / best energies that is refreshed after every
            step. If false, no live display is started at all.
        progress_verbose_print: Passed to ``metropolis_hastings_step`` as
            ``verbose`` (prints a line for every accepted candidate).

    Returns:
        The program held by the chain after the last step, i.e. the last
        accepted candidate (or ``program`` itself if nothing was accepted).
        Only the best *energy* is tracked in the state, so this is not
        necessarily the lowest-energy program that was proposed.
    """
    # TODO(scandido): Track accept rate.
    state = MetropolisHastingsState(
        program=program,
        temperature=initial_temperature,
        annealing_rate=annealing_rate,
        num_steps=0,
        candidate_energy=None,
        candidate_energy_term_fn_values=None,
        current_energy=None,
        current_energy_term_fn_values=None,
        best_energy=None,
        best_energy_term_fn_values=None,
    )

    def _generate_table(state):
        """Build the progress table for ``state`` (header only before step one)."""
        table = Table()
        table.add_column("Energy name")
        table.add_column("Weight")
        table.add_column("Candidate Value")
        table.add_column("Current Value")
        table.add_column("Best Value")

        # Nothing has been evaluated yet, so there are no rows to show.
        if state.current_energy_term_fn_values is None:
            return table

        # The three lists are walked in lockstep; name and weight are taken
        # from the candidate's entry.
        for (name, weight, candidate_value), (_, _, current_value), (_, _, best_value) in zip(
            state.candidate_energy_term_fn_values,
            state.current_energy_term_fn_values,
            state.best_energy_term_fn_values,
        ):
            table.add_row(
                name,
                f"{weight:.2f}",
                f"{candidate_value:.2f}",
                f"{current_value:.2f}",
                f"{best_value:.2f}",
            )

        table.add_row(
            "Energy",
            "",
            f"{state.candidate_energy:.2f}",
            f"{state.current_energy:.2f}",
            f"{state.best_energy:.2f}",
        )
        table.add_row("Iterations", "", f"{state.num_steps} / {total_num_steps}")
        return table

    def _advance(state):
        """Take a single Metropolis-Hastings step from ``state``."""
        return metropolis_hastings_step(
            state,
            folding_callback,
            verbose=progress_verbose_print,
        )

    if not display_progress:
        # No progress requested: do not start a Live session (it would take
        # over the terminal and spawn a refresh thread for nothing).
        for _ in range(total_num_steps):
            state = _advance(state)
        return state.program

    with Live() as live:
        for _ in range(total_num_steps):
            state = _advance(state)
            live.update(_generate_table(state))

    return state.program