JakubGetMe commited on
Commit
12a62f3
·
1 Parent(s): 44422c1

made it possible to pass additional tools to CodeInterpreterSession

Browse files
codeinterpreterapi/session.py CHANGED
@@ -3,7 +3,7 @@ from io import BytesIO
3
  from typing import Optional
4
  from codeboxapi import CodeBox # type: ignore
5
  from codeboxapi.schema import CodeBoxOutput # type: ignore
6
- from langchain.tools import StructuredTool
7
  from langchain.chat_models import ChatOpenAI
8
  from langchain.chat_models.base import BaseChatModel
9
  from langchain.prompts.chat import MessagesPlaceholder
@@ -20,10 +20,16 @@ from codeinterpreterapi.chains.remove_download_link import remove_download_link
20
 
21
 
22
  class CodeInterpreterSession:
23
- def __init__(self, model=None, openai_api_key=settings.OPENAI_API_KEY, verbose=settings.VERBOSE) -> None:
 
 
 
 
 
 
24
  self.codebox = CodeBox()
25
  self.verbose = verbose
26
- self.tools: list[StructuredTool] = self._tools()
27
  self.llm: BaseChatModel = self._llm(model, openai_api_key)
28
  self.agent_executor: AgentExecutor = self._agent_executor()
29
  self.input_files: list[File] = []
@@ -32,8 +38,9 @@ class CodeInterpreterSession:
32
  async def astart(self) -> None:
33
  await self.codebox.astart()
34
 
35
- def _tools(self) -> list[StructuredTool]:
36
- return [
 
37
  StructuredTool(
38
  name="python",
39
  description=
 
3
  from typing import Optional
4
  from codeboxapi import CodeBox # type: ignore
5
  from codeboxapi.schema import CodeBoxOutput # type: ignore
6
+ from langchain.tools import StructuredTool, BaseTool
7
  from langchain.chat_models import ChatOpenAI
8
  from langchain.chat_models.base import BaseChatModel
9
  from langchain.prompts.chat import MessagesPlaceholder
 
20
 
21
 
22
  class CodeInterpreterSession:
23
+ def __init__(
24
+ self,
25
+ model=None,
26
+ openai_api_key=settings.OPENAI_API_KEY,
27
+ verbose=settings.VERBOSE,
28
+ tools: list[BaseTool] = None
29
+ ) -> None:
30
  self.codebox = CodeBox()
31
  self.verbose = verbose
32
+ self.tools: list[BaseTool] = self._tools(tools)
33
  self.llm: BaseChatModel = self._llm(model, openai_api_key)
34
  self.agent_executor: AgentExecutor = self._agent_executor()
35
  self.input_files: list[File] = []
 
38
  async def astart(self) -> None:
39
  await self.codebox.astart()
40
 
41
+ def _tools(self, additional_tools: list[BaseTool] = None) -> list[BaseTool]:
42
+ additional_tools = additional_tools or []
43
+ return additional_tools + [
44
  StructuredTool(
45
  name="python",
46
  description=
examples/use_additional_tools.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The exciting part about this example is
3
+ that the code interpreter has internet access
4
+ so it can download the bitcoin chart from yahoo finance
5
+ and plot it for you
6
+ """
7
+ import csv
8
+ import io
9
+ from datetime import datetime
10
+ from typing import Any
11
+
12
+ from langchain.tools import tool, BaseTool
13
+
14
+ from codeinterpreterapi import CodeInterpreterSession
15
+
16
+
17
+ class ExampleKnowledgeBaseTool(BaseTool):
18
+ name = "salary_database"
19
+ description = "Use to get salary data of company employees"
20
+
21
+ def _run(self, *args, **kwargs):
22
+ raise NotImplementedError()
23
+
24
+ async def _arun(self, *args, **kwargs: Any) -> Any:
25
+ f = io.StringIO()
26
+ writer = csv.writer(f)
27
+ writer.writerow(['month', 'employee', 'salary'])
28
+ writer.writerow(['march 2022', 'Jan', '1200'])
29
+ writer.writerow(['march 2022', 'Ola', '1500'])
30
+ writer.writerow(['april 2022', 'Jan', '1800'])
31
+ writer.writerow(['april 2022', 'Ola', '2000'])
32
+ return f.getvalue()
33
+
34
+
35
+ async def main():
36
+ async with CodeInterpreterSession(tools=[ExampleKnowledgeBaseTool()]) as session:
37
+ response = await session.generate_response(
38
+ f"Plot chart of company employee salaries"
39
+ )
40
+
41
+ print("AI: ", response.content)
42
+ for file in response.files:
43
+ file.show_image()
44
+
45
+
46
+ if __name__ == "__main__":
47
+ import asyncio
48
+
49
+ asyncio.run(main())