# Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from copy import deepcopy from dataclasses import dataclass import numpy as np from rich.live import Live from rich.table import Table from language.folding_callbacks import FoldingCallback from language.program import ProgramNode @dataclass class MetropolisHastingsState: program: ProgramNode temperature: float annealing_rate: float num_steps: int candidate_energy: float candidate_energy_term_fn_values: list current_energy: float current_energy_term_fn_values: list best_energy: float best_energy_term_fn_values: list def metropolis_hastings_step( state: MetropolisHastingsState, folding_callback: FoldingCallback, verbose: bool = False, ) -> MetropolisHastingsState: temperature = state.temperature * state.annealing_rate candidate: ProgramNode = deepcopy(state.program) candidate.mutate() sequence, residue_indices = candidate.get_sequence_and_set_residue_index_ranges() folding_output = folding_callback.fold(sequence, residue_indices) energy_term_fns = candidate.get_energy_term_functions() candidate_energy_term_fn_values = [ (name, weight, energy_fn(folding_output)) for name, weight, energy_fn in energy_term_fns ] # TODO(scandido): Log these. candidate_energy: float = sum( [weight * value for _, weight, value in candidate_energy_term_fn_values] ) accept_candidate = False if state.current_energy is None: accept_candidate = True else: # NOTE(scandido): We are minimizing the function here so instead of # candidate - current we do -1 * (candidate - current) = -candidate + current. energy_differential: float = -candidate_energy + state.current_energy accept_probability: float = np.clip( # NOTE(scandido): We approximate the ratio of transition probabilities from # current to candidate vs. candidate to current to be equal, which is # approximately correct. np.exp(energy_differential / temperature), a_min=None, a_max=1.0, ) accept_candidate: bool = np.random.uniform() < accept_probability if accept_candidate and verbose: print(f"Accepted {sequence} with energy {candidate_energy:.2f}.") best = (state.best_energy is None) or candidate_energy < state.best_energy return MetropolisHastingsState( program=candidate if accept_candidate else state.program, temperature=temperature, annealing_rate=state.annealing_rate, num_steps=state.num_steps + 1, candidate_energy=candidate_energy, candidate_energy_term_fn_values=candidate_energy_term_fn_values, current_energy=candidate_energy if accept_candidate else state.current_energy, current_energy_term_fn_values=candidate_energy_term_fn_values if accept_candidate else state.current_energy_term_fn_values, best_energy=candidate_energy if best else state.best_energy, best_energy_term_fn_values=candidate_energy_term_fn_values if best else state.best_energy_term_fn_values, ) def run_simulated_annealing( program: ProgramNode, initial_temperature: float, annealing_rate: float, total_num_steps: int, folding_callback: FoldingCallback, display_progress: bool = True, progress_verbose_print: bool = False, ) -> ProgramNode: # TODO(scandido): Track accept rate. state = MetropolisHastingsState( program=program, temperature=initial_temperature, annealing_rate=annealing_rate, num_steps=0, candidate_energy=None, candidate_energy_term_fn_values=None, current_energy=None, current_energy_term_fn_values=None, best_energy=None, best_energy_term_fn_values=None, ) def _generate_table(state): table = Table() table.add_column("Energy name") table.add_column("Weight") table.add_column("Candidate Value") table.add_column("Current Value") table.add_column("Best Value") if state.current_energy_term_fn_values is None: return table for (name, weight, candidate_value), (_, _, current_value), (_, _, best_value) in zip( state.candidate_energy_term_fn_values, state.current_energy_term_fn_values, state.best_energy_term_fn_values, ): table.add_row( name, f"{weight:.2f}", f"{candidate_value:.2f}", f"{current_value:.2f}", f"{best_value:.2f}", ) table.add_row( "Energy", "", f"{state.candidate_energy:.2f}", f"{state.current_energy:.2f}", f"{state.best_energy:.2f}", ) table.add_row("Iterations", "", f"{state.num_steps} / {total_num_steps}") return table with Live() as live: for _ in range(1, total_num_steps + 1): state = metropolis_hastings_step( state, folding_callback, verbose=progress_verbose_print, ) if display_progress: live.update(_generate_table(state)) return state.program