nbaldwin's picture
first version FunSearch
97e363b
# Heavily inspired by https://github.com/google-deepmind/funsearch/tree/main
import numpy as np
from typing import Callable
from .Cluster import Cluster
from .Program import Program
import ast
import astunparse
from typing import Optional,Dict,Any
from .artifacts import AbstractArtifact
from collections.abc import Mapping, Sequence
from copy import deepcopy
from .Program import ProgramVisitor,Program,text_to_artifact
import dataclasses
import scipy
from .utils import _softmax
# Type alias for the per-test scores passed to Island.register_program: a
# mapping from a test key to that test's result, whose "score" entry is read
# to compute a cluster's score.
ScoresPerTest = Mapping
class Island:
    """ An implementation of an Island of the ProgramDB. This code is an implementation of Funsearch (https://www.nature.com/articles/s41586-023-06924-6) and is heavily inspired by the original code (https://github.com/google-deepmind/funsearch)

    Registered programs are grouped into clusters keyed by the string form of
    their per-test scores. Prompts are built by picking clusters via
    ``_softmax`` over the cluster scores (with a temperature that decays and
    resets periodically), sampling one program per picked cluster, and
    substituting the versioned programs into ``template``.

    **Citation**:

    @Article{FunSearch2023,
    author = {Romera-Paredes, Bernardino and Barekatain, Mohammadamin and Novikov, Alexander and Balog, Matej and Kumar, M. Pawan and Dupont, Emilien and Ruiz, Francisco J. R. and Ellenberg, Jordan and Wang, Pengming and Fawzi, Omar and Kohli, Pushmeet and Fawzi, Alhussein},
    journal = {Nature},
    title = {Mathematical discoveries from program search with large language models},
    year = {2023},
    doi = {10.1038/s41586-023-06924-6}
    }
    """

    def __init__(self,
                 artifact_to_evolve_name: str,
                 artifacts_per_prompt: int,
                 temperature: float,
                 temperature_period: int,
                 template: Program,
                 reduce_score_method: Optional[Callable] = np.mean,
                 sample_with_replacement: Optional[bool] = False):
        """ Create an empty island.

        :param artifact_to_evolve_name: Name of the artifact being evolved; versioned names in prompts are ``<name>_v<i>``
        :type artifact_to_evolve_name: str
        :param artifacts_per_prompt: Maximum number of existing implementations placed in a prompt
        :type artifacts_per_prompt: int
        :param temperature: Initial temperature used when picking clusters
        :type temperature: float
        :param temperature_period: Number of registered programs after which the temperature schedule restarts
        :type temperature_period: int
        :param template: Program whose artifacts are replaced by the versioned implementations to form the prompt
        :type template: Program
        :param reduce_score_method: Reduces a new cluster's array of per-test scores to a single score
        :type reduce_score_method: Callable
        :param sample_with_replacement: Whether clusters are picked with replacement; also forwarded to each Cluster
        :type sample_with_replacement: bool
        """
        self.artifact_to_evolve_name: str = artifact_to_evolve_name
        self.artifacts_per_prompt: int = artifacts_per_prompt
        self.temperature: float = temperature
        self.temperature_period: int = temperature_period
        # Maps a score-signature string (see register_program) to its cluster.
        self.clusters: Dict[str, Cluster] = {}
        self.template: Program = template
        # Drives the temperature schedule in pick_clusters.
        self.total_programs_on_island: int = 0
        self.reduce_score_method: Callable = reduce_score_method
        self.sample_with_replacement: bool = sample_with_replacement

    def register_program(self, program: AbstractArtifact, scores_per_test: ScoresPerTest):
        """ Register a program on the island.

        Programs whose per-test scores stringify identically share a cluster.

        :param program: The program to register
        :type program: AbstractArtifact
        :param scores_per_test: The scores per test of the program; each value must have a "score" entry
        :type scores_per_test: Dict[str,Any]
        """
        scores_per_test_key = " ".join(
            str(key) + ":" + str(score) for key, score in scores_per_test.items()
        )
        if scores_per_test_key in self.clusters:
            self.clusters[scores_per_test_key].register_program(program)
        else:
            # The reduced score is only needed when a new cluster is created;
            # an existing cluster keeps the score it was created with.
            scores_per_test_values = np.array(
                [score_per_test["score"] for score_per_test in scores_per_test.values()]
            )
            score = self.reduce_score_method(scores_per_test_values)
            self.clusters[scores_per_test_key] = Cluster(
                score=score,
                first_program=program,
                sample_with_replacement=self.sample_with_replacement,
            )
        self.total_programs_on_island += 1

    def pick_clusters(self):
        """ Pick the clusters to generate the prompt

        :return: The clusters and their names (score-signature keys)
        :rtype: Tuple[List[Cluster],List[str]]
        """
        cluster_keys = list(self.clusters.keys())
        clusters = [self.clusters[key] for key in cluster_keys]
        cluster_scores = np.array([cluster.score for cluster in clusters])
        # The temperature decays linearly from `self.temperature` towards 0 as
        # programs are registered, restarting every `temperature_period` programs.
        period = self.temperature_period
        cluster_temperature = self.temperature * (
            1 - (self.total_programs_on_island % period) / period
        )
        probs = _softmax(cluster_scores, cluster_temperature)
        # Can occur at the beginning, when there are not many clusters.
        functions_per_prompt = min(self.artifacts_per_prompt, len(self.clusters))
        if not self.sample_with_replacement:
            # Without replacement, np.random.choice raises ValueError when fewer
            # entries of `p` are non-zero than `size`; a low temperature can
            # underflow some probabilities to exactly 0.
            functions_per_prompt = min(functions_per_prompt, int(np.count_nonzero(probs)))
        select_cluster_ids = np.random.choice(
            len(cluster_scores),
            size=functions_per_prompt,
            p=probs,
            replace=self.sample_with_replacement,
        )
        chosen_clusters = [clusters[cluster_id] for cluster_id in select_cluster_ids]
        chosen_names = [cluster_keys[cluster_id] for cluster_id in select_cluster_ids]
        return chosen_clusters, chosen_names

    def _get_versioned_artifact_name(self, i):
        """ Get the versioned artifact name

        :param i: The version of the artifact
        :type i: int
        :return: The versioned artifact name, ``<artifact_to_evolve_name>_v<i>``
        :rtype: str
        """
        return f"{self.artifact_to_evolve_name}_v{i}"

    def _generate_prompt(self, implementations: Sequence[AbstractArtifact], chosen_cluster_names: Sequence[str]):
        """ Generate the prompt

        Implementation ``i`` is renamed to version ``i`` and documented with its
        cluster's scores. A header for version ``len(implementations)`` with an
        empty body is appended for the LLM to complete.

        :param implementations: The implementations, in the order they should appear (at least one is required)
        :type implementations: Sequence[AbstractArtifact]
        :param chosen_cluster_names: The chosen cluster names, aligned with ``implementations``
        :type chosen_cluster_names: Sequence[str]
        :return: The prompt
        :rtype: str
        """
        # Work on copies: names and docstrings are mutated below.
        implementations = deepcopy(implementations)
        versioned_artifacts: list[AbstractArtifact] = []
        for i, implementation in enumerate(implementations):
            new_artifact_name = self._get_versioned_artifact_name(i)
            implementation.name = new_artifact_name
            # The cluster name is already the space-separated per-test score string.
            implementation.docstring = f'Scores per test: {chosen_cluster_names[i]}'
            if i >= 1:
                implementation.docstring += f'\nImproved version of {self._get_versioned_artifact_name(i - 1)}'
            # Rewrite calls to the un-versioned name (e.g. recursive calls) so
            # they target this version.
            renamed = implementation.rename_artifact_calls(
                source_name=self.artifact_to_evolve_name,
                target_name=new_artifact_name,
            )
            versioned_artifacts.append(text_to_artifact(renamed))

        # Create the header of the function to be generated by the LLM.
        next_version = len(implementations)
        header = dataclasses.replace(
            implementations[-1],
            name=self._get_versioned_artifact_name(next_version),
            body='',
            docstring=f'Improved version of {self._get_versioned_artifact_name(next_version - 1)}',
        )
        versioned_artifacts.append(header)
        # Substitute the versioned artifacts into the prompt template.
        prompt = dataclasses.replace(self.template, artifacts=versioned_artifacts)
        return str(prompt)

    def get_prompt(self):
        """ Get the prompt

        :return: The prompt, and the number of implementations it contains
            (which is also the version number of the artifact the LLM is asked to write)
        :rtype: Tuple[str, int]
        """
        chosen_clusters, chosen_cluster_names = self.pick_clusters()
        scores = [cluster.score for cluster in chosen_clusters]
        # Ascending score order: the best-scoring implementation gets the
        # highest version number and appears last in the prompt.
        indices = np.argsort(scores)
        sorted_implementations = [chosen_clusters[i].sample_program() for i in indices]
        sorted_cluster_names = [chosen_cluster_names[i] for i in indices]
        version_generated = len(sorted_implementations)
        return self._generate_prompt(sorted_implementations, sorted_cluster_names), version_generated