Spaces:
Runtime error
Runtime error
Shroominic
commited on
Commit
·
2ea8b71
1
Parent(s):
0282efb
check file modifications prompt
Browse files
codeinterpreterapi/chains/modifications_check.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from json import JSONDecodeError
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
from langchain.base_language import BaseLanguageModel
|
| 6 |
+
from langchain.chat_models.openai import ChatOpenAI
|
| 7 |
+
from langchain.prompts.chat import (
|
| 8 |
+
ChatPromptTemplate,
|
| 9 |
+
HumanMessagePromptTemplate,
|
| 10 |
+
)
|
| 11 |
+
from langchain.schema import (
|
| 12 |
+
AIMessage,
|
| 13 |
+
OutputParserException,
|
| 14 |
+
SystemMessage,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
prompt = ChatPromptTemplate(
|
| 19 |
+
input_variables=["code"],
|
| 20 |
+
messages=[
|
| 21 |
+
SystemMessage(content=
|
| 22 |
+
"The user will input some code and you will need to determine if the code makes any changes to the file system. \n"
|
| 23 |
+
"With changes it means creating new files or modifying exsisting ones.\n"
|
| 24 |
+
"Answer with a function call `determine_modifications` and list them inside.\n"
|
| 25 |
+
"If the code does not make any changes to the file system, still answer with the function call but return an empty list.\n",
|
| 26 |
+
),
|
| 27 |
+
HumanMessagePromptTemplate.from_template("{code}")
|
| 28 |
+
]
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
functions = [
|
| 32 |
+
{
|
| 33 |
+
"name": "determine_modifications",
|
| 34 |
+
"description":
|
| 35 |
+
"Based on code of the user determine if the code makes any changes to the file system. \n"
|
| 36 |
+
"With changes it means creating new files or modifying exsisting ones.\n",
|
| 37 |
+
"parameters": {
|
| 38 |
+
"type": "object",
|
| 39 |
+
"properties": {
|
| 40 |
+
"modifications": {
|
| 41 |
+
"type": "array",
|
| 42 |
+
"items": { "type": "string" },
|
| 43 |
+
"description": "The filenames that are modified by the code.",
|
| 44 |
+
},
|
| 45 |
+
},
|
| 46 |
+
"required": ["modifications"],
|
| 47 |
+
},
|
| 48 |
+
}
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
async def get_file_modifications(
|
| 53 |
+
code: str,
|
| 54 |
+
llm: BaseLanguageModel,
|
| 55 |
+
retry: int = 2,
|
| 56 |
+
) -> List[str] | None:
|
| 57 |
+
if retry < 1:
|
| 58 |
+
return None
|
| 59 |
+
messages = prompt.format_prompt(code=code).to_messages()
|
| 60 |
+
message = await llm.apredict_messages(messages, functions=functions)
|
| 61 |
+
|
| 62 |
+
if not isinstance(message, AIMessage):
|
| 63 |
+
raise OutputParserException("Expected an AIMessage")
|
| 64 |
+
|
| 65 |
+
function_call = message.additional_kwargs.get("function_call", None)
|
| 66 |
+
|
| 67 |
+
if function_call is None:
|
| 68 |
+
return await get_file_modifications(code, llm, retry=retry-1)
|
| 69 |
+
else:
|
| 70 |
+
function_call = json.loads(function_call["arguments"])
|
| 71 |
+
return function_call["modifications"]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
async def test():
|
| 75 |
+
llm = ChatOpenAI(model="gpt-3.5-turbo-0613") # type: ignore
|
| 76 |
+
|
| 77 |
+
code = """
|
| 78 |
+
import matplotlib.pyplot as plt
|
| 79 |
+
|
| 80 |
+
x = list(range(1, 11))
|
| 81 |
+
y = [29, 39, 23, 32, 4, 43, 43, 23, 43, 77]
|
| 82 |
+
|
| 83 |
+
plt.plot(x, y, marker='o')
|
| 84 |
+
plt.xlabel('Index')
|
| 85 |
+
plt.ylabel('Value')
|
| 86 |
+
plt.title('Data Plot')
|
| 87 |
+
|
| 88 |
+
plt.show()
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
code2 = "import pandas as pd\n\n# Read the Excel file\ndata = pd.read_excel('Iris.xlsx')\n\n# Convert the data to CSV\ndata.to_csv('Iris.csv', index=False)"
|
| 92 |
+
|
| 93 |
+
modifications = await get_file_modifications(code2, llm)
|
| 94 |
+
|
| 95 |
+
print(modifications)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
if __name__ == "__main__":
|
| 99 |
+
import asyncio
|
| 100 |
+
from dotenv import load_dotenv
|
| 101 |
+
load_dotenv()
|
| 102 |
+
|
| 103 |
+
asyncio.run(test())
|