File size: 7,219 Bytes
97e363b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# Heavily inspired by https://github.com/google-deepmind/funsearch/tree/main


import numpy as np
from typing import Callable
from .Cluster import Cluster
from .Program import Program
import ast
import astunparse
from typing import  Optional,Dict,Any
from .artifacts import AbstractArtifact
from collections.abc import Mapping, Sequence
from copy import deepcopy
from .Program import ProgramVisitor,Program,text_to_artifact
import dataclasses
import scipy
from .utils import _softmax
# Alias for the per-test score records accepted by Island.register_program:
# a mapping whose values are themselves records indexed with a "score" key.
ScoresPerTest = Mapping



class Island:
    """ An implementation of an Island of the ProgramDB. This code is an implementation of Funsearch (https://www.nature.com/articles/s41586-023-06924-6) and is heavily inspired by the original code (https://github.com/google-deepmind/funsearch)

    Registered programs are grouped into clusters keyed by a signature string
    built from their per-test scores. Prompts are built by sampling clusters
    with a softmax over the cluster scores. The softmax temperature decays
    linearly from ``temperature`` over every ``temperature_period`` registered
    programs and then resets.

    **Citation**:

    @Article{FunSearch2023,
        author  = {Romera-Paredes, Bernardino and Barekatain, Mohammadamin and Novikov, Alexander and Balog, Matej and Kumar, M. Pawan and Dupont, Emilien and Ruiz, Francisco J. R. and Ellenberg, Jordan and Wang, Pengming and Fawzi, Omar and Kohli, Pushmeet and Fawzi, Alhussein},
        journal = {Nature},
        title   = {Mathematical discoveries from program search with large language models},
        year    = {2023},
        doi     = {10.1038/s41586-023-06924-6}
    }
    """
    def __init__(self,
                 artifact_to_evolve_name: str,
                 artifacts_per_prompt: int,
                 temperature: float,
                 temperature_period: int,
                 template: Program,
                 reduce_score_method: Optional[Callable] = np.mean,
                 sample_with_replacement: Optional[bool] = False):
        """ Create an empty island.

        :param artifact_to_evolve_name: Name of the evolved artifact; prompt versions are named ``<name>_v<i>``
        :type artifact_to_evolve_name: str
        :param artifacts_per_prompt: Maximum number of sampled implementations placed in one prompt
        :type artifacts_per_prompt: int
        :param temperature: Initial softmax temperature used when sampling clusters
        :type temperature: float
        :param temperature_period: Number of registered programs after which the temperature schedule restarts
        :type temperature_period: int
        :param template: Program whose ``artifacts`` are replaced by the versioned implementations to form a prompt
        :type template: Program
        :param reduce_score_method: Reduces the array of per-test scores of a program to one cluster score; ``None`` falls back to ``np.mean``
        :type reduce_score_method: Optional[Callable]
        :param sample_with_replacement: Whether clusters are sampled with replacement (also forwarded to each ``Cluster``)
        :type sample_with_replacement: Optional[bool]
        """
        self.artifact_to_evolve_name: str = artifact_to_evolve_name
        self.artifacts_per_prompt: int = artifacts_per_prompt

        self.temperature: float = temperature
        self.temperature_period: int = temperature_period

        # Maps the per-test score signature built in register_program to its cluster.
        self.clusters: Dict[str, Cluster] = {}
        self.template: Program = template

        # Counts every registered program; drives the temperature schedule in pick_clusters.
        self.total_programs_on_island: int = 0
        # The signature advertises Optional: treat an explicit None as "use the default reducer"
        # instead of failing later with "'NoneType' object is not callable".
        self.reduce_score_method: Callable = np.mean if reduce_score_method is None else reduce_score_method
        self.sample_with_replacement: bool = sample_with_replacement

    def register_program(self, program: AbstractArtifact, scores_per_test: ScoresPerTest):
        """ Register a program on the island.

        Programs whose ``scores_per_test`` render to the same signature string
        share a cluster. When a new cluster is created, its score is
        ``reduce_score_method`` applied to the array of the ``"score"`` entries
        of the per-test records.

        :param program: The program to register
        :type program: AbstractArtifact
        :param scores_per_test: Per-test records, each indexed with a ``"score"`` key
        :type scores_per_test: Dict[str,Any]
        """
        # The signature uses str() of each whole per-test record, not only its "score" entry.
        scores_per_test_key = " ".join([str(key) + ":" + str(record) for key, record in scores_per_test.items()])

        # Built unconditionally so a record without "score" is rejected even for existing clusters.
        scores_per_test_values = np.array([record["score"] for record in scores_per_test.values()])

        if scores_per_test_key not in self.clusters:
            score = self.reduce_score_method(scores_per_test_values)
            self.clusters[scores_per_test_key] = Cluster(score=score, first_program=program, sample_with_replacement=self.sample_with_replacement)
        else:
            self.clusters[scores_per_test_key].register_program(program)

        self.total_programs_on_island += 1

    def pick_clusters(self):
        """ Pick the clusters to generate the prompt.

        Clusters are drawn with probabilities ``_softmax(scores, T)``. ``T``
        decays linearly from ``temperature`` towards 0 over each
        ``temperature_period`` registered programs and then resets. At most
        ``artifacts_per_prompt`` clusters are drawn, and never more than exist.
        When sampling without replacement, never more than the number of
        clusters with non-zero probability are drawn either.

        :return: The clusters and their names
        :rtype: Tuple[List[Cluster],List[str]]
        """
        cluster_keys = list(self.clusters.keys())
        clusters = [self.clusters[key] for key in cluster_keys]

        cluster_scores = np.array([cluster.score for cluster in clusters])
        cluster_temperature = self.temperature * (1 - (self.total_programs_on_island % self.temperature_period) / self.temperature_period)
        probs = _softmax(cluster_scores, cluster_temperature)

        # Can occur at the beginning, when there are not many clusters.
        functions_per_prompt = min(self.artifacts_per_prompt, len(clusters))
        if not self.sample_with_replacement:
            # np.random.choice(replace=False) raises
            # "Fewer non-zero entries in p than size" when the softmax
            # underflows to 0 for some clusters (low temperature late in a
            # period). Cap the draw at the clusters that can actually be picked.
            functions_per_prompt = min(functions_per_prompt, int(np.count_nonzero(probs)))
        select_cluster_ids = np.random.choice(len(cluster_scores), size=functions_per_prompt, p=probs, replace=self.sample_with_replacement)

        chosen_clusters = [clusters[cluster_id] for cluster_id in select_cluster_ids]
        chosen_names = [cluster_keys[cluster_id] for cluster_id in select_cluster_ids]
        return chosen_clusters, chosen_names

    def _get_versioned_artifact_name(self, i):
        """ Get the versioned artifact name.

        :param i: The version of the artifact
        :type i: int
        :return: The versioned artifact name, ``<artifact_to_evolve_name>_v<i>``
        :rtype: str
        """
        return f"{self.artifact_to_evolve_name}_v{i}"

    def _generate_prompt(self, implementations: Sequence[AbstractArtifact], chosen_cluster_names: Sequence[str]):
        """ Generate the prompt.

        Each implementation ``i`` is renamed to ``<name>_v<i>``. Its docstring is
        set to its cluster's score signature, and, for ``i >= 1``, also notes it
        improves on ``_v<i-1>``. A final header artifact ``_v<len(implementations)>``
        with an empty body is appended for the LLM to complete. The versioned
        list replaces the template's artifacts.

        Requires at least one implementation (the header is derived from the last one).

        :param implementations: The implementations
        :type implementations: Sequence[AbstractArtifact]
        :param chosen_cluster_names: The chosen cluster names, aligned with ``implementations``
        :type chosen_cluster_names: Sequence[str]
        :return: The prompt
        :rtype: str
        """
        # Work on copies: names and docstrings are mutated below.
        implementations = deepcopy(implementations)

        versioned_artifacts: list[AbstractArtifact] = []

        for i, implementation in enumerate(implementations):
            new_artifact_name = self._get_versioned_artifact_name(i)
            implementation.name = new_artifact_name
            implementation.docstring = f'Scores per test: {chosen_cluster_names[i]}'
            if i >= 1:
                implementation.docstring += f'\nImproved version of {self._get_versioned_artifact_name(i - 1)}'
            # Rename calls to the evolved artifact (e.g. recursion) to this version's name;
            # the result is handed to text_to_artifact to rebuild the artifact.
            implementation = implementation.rename_artifact_calls(source_name=self.artifact_to_evolve_name, target_name=new_artifact_name)
            versioned_artifacts.append(text_to_artifact(implementation))

        # Create the header of the function to be generated by the LLM.
        next_version = len(implementations)
        new_artifact_name = self._get_versioned_artifact_name(next_version)

        docstring = f'Improved version of {self._get_versioned_artifact_name(next_version - 1)}'

        header = dataclasses.replace(
            implementations[-1],
            name=new_artifact_name,
            body='',
            docstring=docstring
        )

        versioned_artifacts.append(header)
        # Replace the template's artifacts with the versioned sequence built above.
        prompt = dataclasses.replace(self.template, artifacts=versioned_artifacts)
        return str(prompt)

    def get_prompt(self):
        """ Get the prompt.

        Samples clusters, draws one program from each, and orders them by
        ascending cluster score so the best-scoring implementation gets the
        highest version number, just before the header to be completed.

        :return: The prompt and the version index of the header artifact the LLM should complete
        :rtype: Tuple[str, int]
        """
        chosen_clusters, chosen_cluster_names = self.pick_clusters()

        scores = [cluster.score for cluster in chosen_clusters]

        indices = np.argsort(scores)

        sorted_implementations = [chosen_clusters[i].sample_program() for i in indices]
        sorted_cluster_names = [chosen_cluster_names[i] for i in indices]

        # Matches the header name ``_v<len(implementations)>`` built in _generate_prompt.
        version_generated = len(sorted_implementations)
        return self._generate_prompt(sorted_implementations, sorted_cluster_names), version_generated