File size: 9,306 Bytes
793c96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
from typing_extensions import List, Any, Union
from e2b_code_interpreter import Sandbox
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
import os
import sys
import base64
import gradio as gr
import time
import structlog;log=structlog.get_logger()

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # not the nicest way of handling this, but oh well...
from generation.analyzer import Analyzer, CodeAct
from retrieval.retriever import Metadata
from utils import generate_system_prompt, get_file_from_title, get_path_from_title


class SimpleAnalyzer(Analyzer):
    """
    Uploads datasets to E2B sandbox, then runs AI-generated code in the sandbox to analyze the data.
    Is able to fix errors in the code and retry the analysis.
    """
    def __init__(self, groupOwner, metadata_docs: List[Metadata], coding_client=None, timeout=60, max_steps=5):
        super().__init__(groupOwner, metadata_docs, coding_client)

        self.coding_llm = self.coding_client.with_structured_output(CodeAct, include_raw=True)
        self.sbx = Sandbox(template="ck9xjl2f121b2synr4f5", api_key=os.environ.get("E2B_API_KEY"), timeout=timeout)
        self.max_steps = max_steps

        setup_code, logs, file_names = self.init_sandbox()
        self.init_conversation(setup_code, logs, file_names)


    def finalize(self):
        log.info("Killing the sandbox...")
        self.sbx.kill()


    def run_ai_generated_code(self, ai_generated_code: str) -> tuple[bool, List[Any]]:
        """
        Run the AI-generated code in the sandbox and return the results.
        
        :param ai_generated_code: The AI-generated code to run.
        :return: A tuple containing a boolean indicating success and a list of results.
        """

        log.info("Will run following code in the sandbox:")
        log.info("========================== Code ==========================")
        log.info(ai_generated_code)
        log.info("==========================================================\n")
        execution = self.sbx.run_code(ai_generated_code)
        log.info('Code execution finished!')

        response = []

        print_log = "\n".join(execution.logs.stdout)
        response.append(print_log)
        log.info(print_log)

        # First let's check if the code ran successfully.
        if execution.error:
            log.info('AI-generated code had an error.')
            log.info(execution.error.traceback)
            response.append(execution.error.traceback)
            return False, response # stdout, stderr


        # Iterate over all the results and specifically check for png files that will represent the chart.
        result_idx = 0
        log.info(f"Number of results: {len(execution.results)}")
        for result in execution.results:
            if result.text:
                response.append(result.text)
            if result.png:
                # Save the png to a file
                # The png is in base64 format.
                with open(f'chart-{result_idx}.png', 'wb') as f:
                    f.write(base64.b64decode(result.png))
                response.append(gr.Image(f'chart-{result_idx}.png', width=800, show_label=False))
                log.info(f'Chart saved to chart-{result_idx}.png')
                result_idx += 1

        return True, response


    def init_sandbox(self) -> tuple[str, List[Any], List[str]]:
        """
        Initialize the E2B sandbox by uploading relevant datasets and running setup code.

        :return: A tuple containing the setup code, the logs from the setup code, and the names of the uploaded datasets
        """

        try:
            file_names = [m.title for m in self.metadata_docs]

            # Upload datasets to sandbox
            for file_name in file_names:
                file_path = get_path_from_title(self.groupOwner, file_name)
                with open(file_path, "rb") as f:
                    self.sbx.files.write(get_file_from_title(self.groupOwner, file_name), f)
                log.info(f"Dataset {get_file_from_title(self.groupOwner, file_name)} uploaded to sandbox!")

            setup_code = f"""
            import geopandas as gpd
            import matplotlib.pyplot as plt
            import contextily as cx
            from geopy.geocoders import Nominatim
            
            geolocator = Nominatim(user_agent="ogd4all")

            """

            cleaned_file_names = [get_file_from_title(self.groupOwner, file_name) for file_name in file_names]

            # Initial code for LLM to understand data:
            # - For every dataset, print layer names
            for idx, file_name in enumerate(cleaned_file_names):
                name = file_name.split(".")[0] # remove suffix
                setup_code += f"""
                # Get layers included in dataset {name}
                path_{idx} = "{file_name}"
                print(f"Layers in dataset {name}: {{gpd.list_layers(path_{idx})['name'].to_list()}}")
                
                """

            # Fix indentation of the code
            setup_code = "\n".join([line.strip() for line in setup_code.split("\n")])

            # Run code on sandbox
            success, logs = self.run_ai_generated_code(setup_code)
            if not success:
                log.error('Setup code had an error.')
                log.error(logs)
                self.finalize()
                sys.exit(1)
            else:
                log.info("Setup code ran successfully!")
                log.info(logs)


            return setup_code, logs, cleaned_file_names
        except Exception as e:
            log.error("Caught an exception in sandbox init: %s", e, exc_info=True, backtrace=True, diagnose=True)
            self.finalize()
            sys.exit(1)
    
    
    def init_conversation(self, setup_code, logs, file_names):
        """
        Initialize the conversation with the analyzer

        :param logs: The logs from the setup code.
        :param file_names: The names of the uploaded datasets.
        """

        system_prompt = generate_system_prompt(file_names, self.metadata_docs, setup_code, logs)

        log.info("System prompt:")
        log.info(system_prompt)

        self.messages.append(SystemMessage(system_prompt))


    def analyze(self, query: Union[str, dict]):
        # Extract text from multimodal input for backward compatibility
        if isinstance(query, dict) and "text" in query:
            text_query = query["text"]
        elif isinstance(query, str):
            text_query = query
        else:
            text_query = str(query)
            
        self.messages.append(HumanMessage(f"Request/Question: {text_query}\n")),
        start_time = time.time()

        log.info(f"User question: {text_query}")

        thought_msg = gr.ChatMessage(
            role="assistant",
            content="",
            metadata={"title": "Analyzing the data...",
                      "status": "pending"},
        )

        yield thought_msg

        steps = 0

        while steps < self.max_steps:
            steps += 1
            response = self.coding_llm.invoke(self.messages)
            codeAct = response["parsed"]
            self.total_tokens += response.usage_metadata["total_tokens"]

            # If the LLM wrapped the code with ```python ``` or ``` ```, remove it
            codeAct.code = codeAct.code.replace("```python", "").replace("```", "").strip()

            thought_msg.content += f"**Thought**: {codeAct.thought}\n"
            thought_msg.content += f"**Code**: \n```python\n{codeAct.code}\n``` \n"
            yield thought_msg

            success, response = self.run_ai_generated_code(codeAct.code)
            if not success:
                thought_msg.content += f"**Output**:\n```\n{response[0].replace("```", "")}\n```\n**Error**:\n```\n{response[1].replace("```", "")}\n```\n"
                new_msgs = [AIMessage("Thought: " + codeAct.thought), AIMessage("**Code**: " + f"```python\n{codeAct.code}\n```"), AIMessage(f"**Output**: \n{response[0]}\nError: \n{response[1]}"), HumanMessage("Please fix the error and try again.")]
                self.messages = self.messages + new_msgs
                yield thought_msg
            else:
                thought_msg.metadata["status"] = "done"
                thought_msg.metadata["title"] = "Analysis completed"
                thought_msg.metadata["duration"] = time.time() - start_time

                # Concat all text responses
                output_texts = [r for r in response if isinstance(r, str)]
                output_text = "\n".join(output_texts)

                new_msgs = [AIMessage("Thought: " + codeAct.thought), AIMessage("**Code**: " + f"```python\n{codeAct.code}\n```"), AIMessage(f"**Output**: \n{output_text}")]
                self.messages = self.messages + new_msgs

                yield [thought_msg] + response
                return
            
        thought_msg.metadata["status"] = "done"
        thought_msg.metadata["title"] = f"Maximum thinking steps ({self.max_steps}) exceeded"
        thought_msg.metadata["duration"] = time.time() - start_time
        yield [thought_msg, "Unfortunately, I was not able to find a solution to your request."]
        return