Shroominic commited on
Commit
0282efb
·
1 Parent(s): 2160f9d

implement Session

Browse files
Files changed (1) hide show
  1. codeinterpreterapi/session.py +166 -36
codeinterpreterapi/session.py CHANGED
@@ -1,46 +1,176 @@
1
- from codeboxapi import CodeBox
2
- from promptkit import ChatSession
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
 
5
  class CodeInterpreterSession():
6
 
7
- def __init__(self):
8
- self.chatgpt = ChatSession()
9
  self.codebox = CodeBox()
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- async def _init(self):
12
- await self.codebox.start()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- async def _close(self):
15
- await self.codebox.stop()
 
 
 
 
 
16
 
17
- async def code_decision(self, user_request: str):
18
- # check if the user wants something that requires python code execution
19
- # if yes, return "code"
20
- # if no, return "default"
21
- pass
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- async def generate_response(self, text: str, files: list[dict[str, bytes]]): # list of "file_name" x "file_content"
24
- """ Generate a Code Interpreter response based on the user's input."""
25
- if self.code_decision() == "code":
26
- pass
27
- # plan what code to write (potentially multiple steps)
28
- # code = chatgpt.run(code generation template)
29
- # codebox.run(code)
30
- # on error
31
- # check if package is required
32
- # if yes, install package
33
- # ask for analysis if the error can be fixed
34
- # if yes, continue code generation
35
- # if no, return AssistantResponse
36
- # on success
37
- # check if to output files to the user
38
- # if yes, return AssistantResponse with files
39
- # write a response based on the code execution
40
- # return AssistantResponse
41
- else:
42
- pass
43
- # return AssistantResponse
44
- pass
45
 
46
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid, base64, re
2
+ from io import BytesIO
3
+ from codeboxapi import CodeBox # type: ignore
4
+ from codeboxapi.schema import CodeBoxOutput # type: ignore
5
+ from langchain.tools import StructuredTool
6
+ from langchain.chat_models import ChatOpenAI
7
+ from langchain.chat_models.base import BaseChatModel
8
+ from langchain.prompts.chat import MessagesPlaceholder
9
+ from langchain.agents import AgentExecutor, BaseSingleActionAgent
10
+ from langchain.memory import ConversationBufferMemory
11
+
12
+ from codeinterpreterapi.schemas import CodeInterpreterResponse, CodeInput, File, UserRequest # type: ignore
13
+ from codeinterpreterapi.config import settings
14
+ from codeinterpreterapi.functions_agent import OpenAIFunctionsAgent
15
+ from codeinterpreterapi.prompts import code_interpreter_system_message
16
+ from codeinterpreterapi.callbacks import CodeCallbackHandler
17
+ from codeinterpreterapi.chains.modifications_check import get_file_modifications
18
+ from codeinterpreterapi.chains.remove_download_link import remove_download_link
19
 
20
 
21
  class CodeInterpreterSession():
22
 
23
+ def __init__(self, model=None, openai_api_key=None) -> None:
 
24
  self.codebox = CodeBox()
25
+ self.tools: list[StructuredTool] = self._tools()
26
+ self.llm: BaseChatModel = self._llm(model, openai_api_key)
27
+ self.agent_executor: AgentExecutor = self._agent_executor()
28
+ self.input_files: list[File] = []
29
+ self.output_files: list[File] = []
30
+
31
+ async def _init(self) -> None:
32
+ await self.codebox.astart()
33
+
34
+ async def _close(self) -> None:
35
+ await self.codebox.astop()
36
 
37
+ def _tools(self) -> list[StructuredTool]:
38
+ return [
39
+ StructuredTool(
40
+ name = "python",
41
+ description =
42
+ # TODO: variables as context to the agent
43
+ # TODO: current files as context to the agent
44
+ "Input a string of code to a python interpreter (jupyter kernel). "
45
+ "Variables are preserved between runs. ",
46
+ func = self.codebox.run,
47
+ coroutine = self.arun_handler,
48
+ args_schema = CodeInput,
49
+ ),
50
+ ]
51
+
52
+ def _llm(self, model: str | None, openai_api_key: str | None) -> BaseChatModel:
53
+ if model is None:
54
+ model = "gpt-4"
55
+
56
+ if openai_api_key is None:
57
+ if settings.OPENAI_API_KEY is None:
58
+ raise ValueError("OpenAI API key missing.")
59
+ else:
60
+ openai_api_key = settings.OPENAI_API_KEY
61
 
62
+ return ChatOpenAI(
63
+ temperature=0.03,
64
+ model=model,
65
+ openai_api_key=openai_api_key,
66
+ max_retries=3,
67
+ request_timeout=60*3,
68
+ ) # type: ignore
69
 
70
+ def _agent(self) -> BaseSingleActionAgent:
71
+ return OpenAIFunctionsAgent.from_llm_and_tools(
72
+ llm=self.llm,
73
+ tools=self.tools,
74
+ system_message=code_interpreter_system_message,
75
+ extra_prompt_messages=[MessagesPlaceholder(variable_name="memory")],
76
+ )
77
+
78
+ def _agent_executor(self) -> AgentExecutor:
79
+ return AgentExecutor.from_agent_and_tools(
80
+ agent=self._agent(),
81
+ callbacks=[CodeCallbackHandler(self)],
82
+ max_iterations=9,
83
+ tools=self.tools,
84
+ verbose=settings.VERBOSE,
85
+ memory=ConversationBufferMemory(memory_key="memory", return_messages=True),
86
+ )
87
 
88
+ async def show_code(self, code: str) -> None:
89
+ """ Callback function to show code to the user. """
90
+ if settings.VERBOSE:
91
+ print(code)
92
+
93
+ def run_handler(self, code: str):
94
+ raise NotImplementedError("Use arun_handler for now.")
95
+
96
+ async def arun_handler(self, code: str):
97
+ """ Run code in container and send the output to the user """
98
+ # TODO: upload files
99
+ output: CodeBoxOutput = await self.codebox.arun(code)
100
+
101
+ if not isinstance(output.content, str):
102
+ raise TypeError("Expected output.content to be a string.")
103
+
104
+ if output.type == "image/png":
105
+ filename = f"image-{uuid.uuid4()}.png"
106
+ file_buffer = BytesIO(base64.b64decode(output.content))
107
+ file_buffer.name = filename
108
+ # self.output_files.append(discord.File(path_like_file, filename)) TODO: add to output_files
109
+ return f"Image {filename} got send to the user."
110
 
111
+ elif output.type == "error":
112
+ # TODO: check if package install is required
113
+ # TODO: preanalyze error to optimize next code generation
114
+ print("Error:", output.content)
115
+
116
+ elif (modifications := await get_file_modifications(code, self.llm)):
117
+ for filename in modifications:
118
+ if filename in [file.name for file in self.input_files]:
119
+ continue
120
+ fileb = await self.codebox.adownload(filename)
121
+ if not fileb.content:
122
+ continue
123
+ file_buffer = BytesIO(fileb.content)
124
+ file_buffer.name = filename
125
+ self.output_files.append(File(name=filename, content=file_buffer.read()))
126
+
127
+ return output.content
128
+
129
+ async def input_handler(self, request: UserRequest):
130
+ if not request.files:
131
+ return
132
+ if not request.content:
133
+ request.content = "I uploaded, just text me back and confirm that you got the file(s)."
134
+ request.content += "\n**The user uploaded the following files: **\n"
135
+ for file in request.files:
136
+ self.input_files.append(file)
137
+ request.content += f"[Attachment: {file.name}]\n"
138
+ await self.codebox.aupload(file.name, file.content)
139
+ request.content += "**File(s) are now available in the cwd. **\n"
140
+
141
+ async def output_handler(self, final_response: str) -> CodeInterpreterResponse:
142
+ """ Embed images in the response """
143
+ for file in self.output_files:
144
+ if str(file.name) in final_response:
145
+ # rm ![Any](file.name) from the response
146
+ final_response = re.sub(rf"\n\n!\[.*\]\(.*\)", "", final_response)
147
+
148
+ if self.output_files and re.search(rf"\n\[.*\]\(.*\)", final_response):
149
+ final_response = await remove_download_link(final_response, self.llm)
150
+
151
+ return CodeInterpreterResponse(content=final_response, files=self.output_files)
152
+
153
+ async def generate_response(
154
+ self,
155
+ user_request: UserRequest,
156
+ files: list[File], # list of "file_name" x "file_content"
157
+ detailed_error: bool = False,
158
+ ) -> CodeInterpreterResponse:
159
+ """ Generate a Code Interpreter response based on the user's input."""
160
+ try:
161
+ await self.input_handler(user_request)
162
+ response = await self.agent_executor.arun(input=user_request.content)
163
+ return await self.output_handler(response)
164
+ except Exception as e:
165
+ if settings.VERBOSE:
166
+ import traceback
167
+ traceback.print_exc()
168
+ if detailed_error:
169
+ return CodeInterpreterResponse(content=
170
+ f"Error in CodeInterpreterSession: {e.__class__.__name__} - {e}"
171
+ )
172
+ else:
173
+ return CodeInterpreterResponse(content=
174
+ "Sorry, something went while generating your response."
175
+ "Please try again or restart the session."
176
+ )