File size: 6,807 Bytes
5374a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# Acknowledgement: Modified from AFlow (https://github.com/geekan/MetaGPT/blob/main/metagpt/ext/aflow/scripts/optimizer_utils/graph_utils.py) under MIT License 

import os 
import re
import time
import traceback
from typing import List
from pathlib import Path
from ...core.logging import logger
from ...prompts.optimizers.aflow_optimizer import (
    WORKFLOW_INPUT, 
    WORKFLOW_OPTIMIZE_PROMPT, 
    WORKFLOW_CUSTOM_USE, 
    WORKFLOW_TEMPLATE
)
from ...models.base_model import BaseLLM 
from ...workflow.operators import (
    Operator, Custom, CustomCodeGenerate, 
    ScEnsemble, Test, AnswerGenerate, QAScEnsemble, Programmer
)

# Lookup table from operator name to the operator class imported above.
# GraphUtils._load_operator_description resolves names through this table and
# raises ValueError for any name that is not a key here.
OPERATOR_MAP = {
    "Custom": Custom,
    "CustomCodeGenerate": CustomCodeGenerate,
    "ScEnsemble": ScEnsemble,
    "Test": Test,
    "AnswerGenerate": AnswerGenerate,
    "QAScEnsemble": QAScEnsemble,
    "Programmer": Programmer
}


class GraphUtils:

    def __init__(self, root_path: str) -> None:
        """Remember the root path this helper was created with.

        Args:
            root_path: Path stored unchanged on the instance as ``root_path``.
        """
        self.root_path = root_path

    def create_round_directory(self, graph_path: str, round_number: int) -> str:
        directory = os.path.join(graph_path, f"round_{round_number}")
        os.makedirs(directory, exist_ok=True)
        return directory

    def load_graph(self, round_number: int, workflows_path: str):
        workflows_path = workflows_path.replace("\\", ".").replace("/", ".")
        graph_module_name = f"{workflows_path}.round_{round_number}.graph"
        try:
            graph_module = __import__(graph_module_name, fromlist=[""])
            graph_class = getattr(graph_module, "Workflow")
            return graph_class
        except ImportError as e:
            logger.info(f"Error loading graph for round {round_number}: {e}")
            raise

    def read_graph_files(self, round_number: int, workflows_path: str):
        prompt_file_path = os.path.join(workflows_path, f"round_{round_number}", "prompt.py")
        graph_file_path = os.path.join(workflows_path, f"round_{round_number}", "graph.py")

        try:
            with open(prompt_file_path, "r", encoding="utf-8") as file:
                prompt_content = file.read()
            with open(graph_file_path, "r", encoding="utf-8") as file:
                graph_content = file.read()
        except FileNotFoundError as e:
            logger.info(f"Error: File not found for round {round_number}: {e}")
            raise
        except Exception as e:
            logger.info(f"Error loading prompt for round {round_number}: {e}")
            raise
        return prompt_content, graph_content

    def extract_solve_graph(self, graph_load: str) -> List[str]:
        pattern = r"class Workflow:.+"
        return re.findall(pattern, graph_load, re.DOTALL)

    def load_operators_description(self, operators: List[str], llm: BaseLLM) -> str:

        operators_description = ""
        for id, operator in enumerate(operators):
            operator_description = self._load_operator_description(id + 1, operator, llm)
            operators_description += f"{operator_description}\n"
        return operators_description

    def _load_operator_description(self, id: int, operator_name: str, llm: BaseLLM) -> str:
        """Format the one-line description of a single operator.

        Args:
            id: Number placed at the start of the line.
            operator_name: Key looked up in ``OPERATOR_MAP``.
            llm: Model handed to the operator's constructor as ``llm``.

        Returns:
            A line combining the number, the operator name, and the operator
            instance's ``description`` and ``interface`` attributes.

        Raises:
            ValueError: If ``operator_name`` is not a key of ``OPERATOR_MAP``.
        """
        operator_cls = OPERATOR_MAP.get(operator_name)
        if operator_cls is None:
            raise ValueError(f"Operator {operator_name} not Found in OPERATOR_MAP! Available operators: {OPERATOR_MAP.keys()}")
        instance: Operator = operator_cls(llm=llm)
        description, interface = instance.description, instance.interface
        # NOTE(review): the closing ")" has no opening partner; it is kept
        # verbatim because this text is emitted at runtime.
        return f"{id}. {operator_name}: {description}, with interface {interface})."

    def create_graph_optimize_prompt(
        self,
        experience: str,
        score: float,
        graph: str,
        prompt: str,
        operator_description: str,
        type: str,
        log_data: str,
    ) -> str:
        """Assemble the full prompt text used to request a graph optimization.

        Args:
            experience: Value for the ``experience`` field of ``WORKFLOW_INPUT``.
            score: Value for the ``score`` field of ``WORKFLOW_INPUT``.
            graph: Value for the ``graph`` field of ``WORKFLOW_INPUT``.
            prompt: Value for the ``prompt`` field of ``WORKFLOW_INPUT``.
            operator_description: Value for the ``operator_description`` field.
            type: Value for the ``type`` field of both ``WORKFLOW_INPUT`` and
                ``WORKFLOW_OPTIMIZE_PROMPT``.
            log_data: Value for the ``log`` field of ``WORKFLOW_INPUT``.

        Returns:
            The formatted input section, then ``WORKFLOW_CUSTOM_USE``, then the
            formatted optimize prompt, concatenated in that order.
        """
        input_fields = {
            "experience": experience,
            "score": score,
            "graph": graph,
            "prompt": prompt,
            "operator_description": operator_description,
            "type": type,
            "log": log_data,
        }
        input_section = WORKFLOW_INPUT.format(**input_fields)
        system_section = WORKFLOW_OPTIMIZE_PROMPT.format(type=type)
        return input_section + WORKFLOW_CUSTOM_USE + system_section

    def get_graph_optimize_response(self, graph_optimize_node):
        max_retries = 5
        retries = 0

        while retries < max_retries:
            try:
                response = graph_optimize_node.instruct_content.model_dump()
                return response
            except Exception as e:
                retries += 1
                logger.info(f"Error generating prediction: {e}. Retrying... ({retries}/{max_retries})")
                if retries == max_retries:
                    logger.info("Maximum retries reached. Skipping this sample.")
                    break
                traceback.print_exc()
                time.sleep(5)
        return None

    def write_graph_files(self, directory: str, response: dict):
        """Write a round's ``graph.py``, ``prompt.py`` and ``__init__.py``.

        Args:
            directory: Folder the three files are written into.
            response: Mapping with a ``"graph"`` entry, rendered through
                ``WORKFLOW_TEMPLATE``, and a ``"prompt"`` entry, written after
                removing every ``prompt_custom.`` occurrence.
        """
        rendered_graph = WORKFLOW_TEMPLATE.format(graph=response["graph"])
        graph_file = os.path.join(directory, "graph.py")
        with open(graph_file, "w", encoding="utf-8") as out:
            out.write(rendered_graph)

        with open(os.path.join(directory, "prompt.py"), "w", encoding="utf-8") as out:
            out.write(response["prompt"].replace("prompt_custom.", ""))

        # An empty __init__.py is written next to the two modules
        # (presumably so the round folder can be imported as a package).
        with open(os.path.join(directory, "__init__.py"), "w", encoding="utf-8") as out:
            out.write("")

        # Rewrite the prompt import inside the freshly written graph.py.
        self.update_prompt_import(graph_file, directory)
    
    def update_prompt_import(self, graph_file: str, prompt_folder: str):

        project_root = Path(os.getcwd())
        prompt_folder_path = Path(prompt_folder)

        if not prompt_folder_path.is_absolute():
            prompt_folder_full_path = Path(os.path.join(project_root, prompt_folder))
            if not prompt_folder_full_path.exists():
                raise ValueError(f"Prompt folder {prompt_folder_full_path} does not exist!")
            prompt_folder_path = prompt_folder_full_path
        
        try:
            relative_path = prompt_folder_path.relative_to(project_root)
        except ValueError:
            raise ValueError(f"Prompt folder {prompt_folder} must be within the project directory")

        import_path = str(relative_path).replace(os.sep, ".")
        if import_path.startswith("."):
            import_path = import_path[1:]
        
        with open(graph_file, "r", encoding="utf-8") as file:
            graph_content = file.read()

        # 在graph_content中找到import语句
        pattern = r'import .*?\.prompt as prompt_custom' 
        replacement = f'import {import_path}.prompt as prompt_custom'
        new_content = re.sub(pattern, replacement, graph_content)

        with open(graph_file, "w", encoding="utf-8") as file:
            file.write(new_content)