import os
from tqdm import tqdm
from loguru import logger
import json
from dataclasses import asdict
from agents.Reflexion import Reflexion
from utils.utils import extract_function_signatures, clear_code, extract_function_calls
from prompts import prompt_for_reflection
from memories.Memory import MemoryClassMeta
from models.Base import BaseModel
from retrievers.retriever import BM25Retriever
from prompts import prompt_for_generation
from concurrent.futures import ThreadPoolExecutor, as_completed
class Reflexion_Oneshot(Reflexion):
    """Reflexion agent that augments each generation prompt with a BM25-retrieved
    one-shot code example.

    Per iteration: (1) generate a candidate solution for every unsolved sample,
    (2) execute it via the dataset harness, (3) generate a reflection on failures
    that is fed back into the next generation round.
    """

    # Number of worker threads used when multi_thread is enabled.
    THREAD_NUM = 3

    def __init__(self, model: BaseModel, dataset, corpus_path, mem_file=None, descendant_num=1):
        """
        Args:
            model: LLM backend exposing ``generate(messages, temperature=...)``.
            dataset: problem set; must expose ``problem_states``,
                ``run_single_call`` and ``write_file`` (assumed — confirm against caller).
            corpus_path: path to the retrieval corpus for one-shot examples.
            mem_file: optional ``.json`` file of previously saved memories to resume from.
            descendant_num: forwarded to ``memory_init`` (unused there; kept for interface).
        """
        self.model = model
        self.dataset = dataset
        self.memories = []
        # Two retrievers over the same corpus: one keyed on instruction text,
        # one keyed on code, so one-shots can be matched by either modality.
        self.instruction_retriever = BM25Retriever()
        self.instruction_retriever.process(content_input_path=corpus_path)
        self.code_retriever = BM25Retriever(mode="code")
        self.code_retriever.process(content_input_path=corpus_path)
        self.memory_init(mem_file, descendant_num)

    def memory_init(self, mem_file=None, descendant_num=1):
        """Build one ``Memory`` record per problem state.

        When ``mem_file`` is given, resume from the saved per-filename memories;
        otherwise bootstrap each memory with a BM25-retrieved one-shot snippet.
        """
        class Memory(metaclass=MemoryClassMeta, field_names=["ps",
                                                             "err_msg",
                                                             "reflection",
                                                             "function_signatures",
                                                             "oneshot",
                                                             "pass_call",
                                                             ]):
            pass

        input_mems = None
        if mem_file is not None:
            assert mem_file.endswith(".json"), f"expect a json file, but got {mem_file} instead"
            with open(mem_file, "r") as f:
                input_mems = json.load(f)
            assert len(input_mems) == len(self.dataset), f"expect {len(self.dataset)} samples, but got {len(input_mems)} instead"

        for ps in self.dataset.problem_states:
            # Signatures can only be extracted when a reference label exists.
            fs_mem = extract_function_signatures(ps.label) if ps.label else None
            if input_mems is None:
                # Fresh start: the closest corpus snippet becomes the one-shot.
                os_mem = self.instruction_retriever.query(ps.instruction)[0]
                tmp_mem = Memory(ps=ps,
                                 err_msg=None,
                                 reflection=None,
                                 function_signatures=fs_mem,
                                 oneshot=os_mem["code"],
                                 pass_call=False,
                                 )
            else:
                # Resume: restore the saved state for this sample's filename.
                input_mem = input_mems[ps.filename]
                tmp_mem = Memory(ps=ps,
                                 err_msg=input_mem["err_msg"],
                                 reflection=input_mem["reflection"],
                                 function_signatures=fs_mem,
                                 oneshot=input_mem["oneshot"],
                                 pass_call=input_mem["pass_call"],
                                 )
            self.memories.append(tmp_mem)

    def run(self, output_path=None, multi_thread=True, verbose=False, datalen=None, iteration_num=0, temperature=0):
        """Run ``iteration_num`` generate/execute/reflect rounds.

        Args:
            output_path: if given, results are written to ``<root>_<i><ext>``
                after every iteration.
            multi_thread: fan generation/reflection out over a small thread pool.
            verbose: unused; kept for interface compatibility.
            datalen: optional cap on the number of samples processed.
            iteration_num: number of rounds (0 means no work).
            temperature: sampling temperature forwarded to the model.
        """
        data_len = datalen if datalen else len(self.dataset)
        for iteration in range(iteration_num):
            logger.info(f"\n=== Iteration {iteration} ===")
            iter_path = None
            if output_path is not None:
                root, extension = os.path.splitext(output_path)
                iter_path = f"{root}_{iteration}{extension}"

            logger.info(f"\ngenerate solution")
            self._run_stage(self.generate_solution, data_len, multi_thread, temperature)

            logger.info(f"\nrun scripts on gpu")
            for mem in tqdm(self.memories[:data_len]):
                if mem.pass_call:
                    continue
                is_pass, err_msg = self.dataset.run_single_call(mem.ps)
                if is_pass:
                    # BUGFIX: record success so later iterations (and the
                    # generate/reflect stages) skip this sample; previously
                    # pass_call was never set and passing samples were redone.
                    mem.pass_call = True
                else:
                    mem.err_msg = err_msg

            logger.info(f"\ngenerate reflections")
            self._run_stage(self.generate_reflexion, data_len, multi_thread, temperature)

            if iter_path is not None:
                self.dataset.write_file(iter_path)

    def _run_stage(self, fn, data_len, multi_thread, temperature):
        """Apply ``fn(mem, temperature)`` to the first ``data_len`` memories.

        Shared fan-out for the generation and reflection stages: optional
        thread pool plus a progress bar. Worker exceptions are logged rather
        than silently dropped (previously ``future.result()`` was never
        called, so failures vanished).
        """
        with tqdm(total=data_len) as pbar:
            if multi_thread:
                with ThreadPoolExecutor(max_workers=self.THREAD_NUM) as executor:
                    futures = {executor.submit(fn, mem, temperature): mem for mem in self.memories[:data_len]}
                    for future in as_completed(futures):
                        try:
                            future.result()
                        except Exception:
                            logger.exception(f"{fn.__name__} failed for {futures[future].ps.filename}")
                        pbar.update(1)
            else:
                for mem in self.memories[:data_len]:
                    fn(mem, temperature)
                    pbar.update(1)

    def generate_solution(self, mem, temperature=0):
        """Build the generation prompt (instruction + signatures + one-shot +
        previous attempt/feedback) and store the model's cleaned code in
        ``mem.ps.solution``. No-op for samples that already pass."""
        if mem.pass_call:
            return
        tab = "\n"  # f-string expressions cannot contain backslashes (pre-3.12)
        # BUGFIX: function_signatures is None when the sample has no label;
        # iterate an empty list instead of raising TypeError.
        fss_text = "".join(f"* {sig}{tab}" for sig in (mem.function_signatures or []))
        text = prompt_for_generation.prompt.format(
            instruction=mem.ps.instruction,
            function_signatures=fss_text
        )
        if not mem.ps.solution:
            # First attempt: use the instruction-retrieved one-shot.
            text += f"\nHere is an example snippet of code: {mem.oneshot}"
        else:
            # Retry: retrieve a one-shot close to the previous attempt's code
            # and include the attempt plus any execution/reflection feedback.
            one_shot = self.code_retriever.query(mem.ps.solution)[0]["code"]
            text += f"\nHere is an example snippet of code: {one_shot}"
            text += f"\nPrevious attempt implementation:{mem.ps.solution}"
            if mem.err_msg:
                text += f"\nTest messages for previous attempt:{mem.err_msg}"
            if mem.reflection:
                text += f"\nReflection on previous attempt:{mem.reflection}"
        text += "Please output the codes only without explanation, which we can run directly."
        msg = [
            {"role": "user", "content": text},
        ]
        response = self.model.generate(msg, temperature=temperature)
        mem.ps.solution = clear_code(response)
        return

    def generate_reflexion(self, mem, temperature=0):
        """Ask the model to reflect on the failed attempt and store the text
        in ``mem.reflection``. No-op for samples that already pass."""
        if mem.pass_call:
            return
        reflect_txt = prompt_for_reflection.prompt.format(
            problem=mem.ps.instruction,
            solution=mem.ps.solution,
            test_result=mem.err_msg
        )
        reflect_msg = [
            {
                "role": "user",
                "content": reflect_txt
            }
        ]
        mem.reflection = self.model.generate(reflect_msg, temperature=temperature)