Shroominic commited on
Commit
2ea8b71
·
1 Parent(s): 0282efb

check file modifications prompt

Browse files
codeinterpreterapi/chains/modifications_check.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from json import JSONDecodeError
3
+ from typing import List
4
+
5
+ from langchain.base_language import BaseLanguageModel
6
+ from langchain.chat_models.openai import ChatOpenAI
7
+ from langchain.prompts.chat import (
8
+ ChatPromptTemplate,
9
+ HumanMessagePromptTemplate,
10
+ )
11
+ from langchain.schema import (
12
+ AIMessage,
13
+ OutputParserException,
14
+ SystemMessage,
15
+ )
16
+
17
+
18
+ prompt = ChatPromptTemplate(
19
+ input_variables=["code"],
20
+ messages=[
21
+ SystemMessage(content=
22
+ "The user will input some code and you will need to determine if the code makes any changes to the file system. \n"
23
+ "With changes it means creating new files or modifying exsisting ones.\n"
24
+ "Answer with a function call `determine_modifications` and list them inside.\n"
25
+ "If the code does not make any changes to the file system, still answer with the function call but return an empty list.\n",
26
+ ),
27
+ HumanMessagePromptTemplate.from_template("{code}")
28
+ ]
29
+ )
30
+
31
+ functions = [
32
+ {
33
+ "name": "determine_modifications",
34
+ "description":
35
+ "Based on code of the user determine if the code makes any changes to the file system. \n"
36
+ "With changes it means creating new files or modifying exsisting ones.\n",
37
+ "parameters": {
38
+ "type": "object",
39
+ "properties": {
40
+ "modifications": {
41
+ "type": "array",
42
+ "items": { "type": "string" },
43
+ "description": "The filenames that are modified by the code.",
44
+ },
45
+ },
46
+ "required": ["modifications"],
47
+ },
48
+ }
49
+ ]
50
+
51
+
52
+ async def get_file_modifications(
53
+ code: str,
54
+ llm: BaseLanguageModel,
55
+ retry: int = 2,
56
+ ) -> List[str] | None:
57
+ if retry < 1:
58
+ return None
59
+ messages = prompt.format_prompt(code=code).to_messages()
60
+ message = await llm.apredict_messages(messages, functions=functions)
61
+
62
+ if not isinstance(message, AIMessage):
63
+ raise OutputParserException("Expected an AIMessage")
64
+
65
+ function_call = message.additional_kwargs.get("function_call", None)
66
+
67
+ if function_call is None:
68
+ return await get_file_modifications(code, llm, retry=retry-1)
69
+ else:
70
+ function_call = json.loads(function_call["arguments"])
71
+ return function_call["modifications"]
72
+
73
+
74
+ async def test():
75
+ llm = ChatOpenAI(model="gpt-3.5-turbo-0613") # type: ignore
76
+
77
+ code = """
78
+ import matplotlib.pyplot as plt
79
+
80
+ x = list(range(1, 11))
81
+ y = [29, 39, 23, 32, 4, 43, 43, 23, 43, 77]
82
+
83
+ plt.plot(x, y, marker='o')
84
+ plt.xlabel('Index')
85
+ plt.ylabel('Value')
86
+ plt.title('Data Plot')
87
+
88
+ plt.show()
89
+ """
90
+
91
+ code2 = "import pandas as pd\n\n# Read the Excel file\ndata = pd.read_excel('Iris.xlsx')\n\n# Convert the data to CSV\ndata.to_csv('Iris.csv', index=False)"
92
+
93
+ modifications = await get_file_modifications(code2, llm)
94
+
95
+ print(modifications)
96
+
97
+
98
+ if __name__ == "__main__":
99
+ import asyncio
100
+ from dotenv import load_dotenv
101
+ load_dotenv()
102
+
103
+ asyncio.run(test())