codeinterpreter / codeinterpreterapi /chains /modifications_check.py
Shroominic
check file modifications prompt
2ea8b71
raw
history blame
3.1 kB
import json
from json import JSONDecodeError
from typing import List
from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.schema import (
AIMessage,
OutputParserException,
SystemMessage,
)
# System instructions for the file-modification check. The model is told to
# always answer through the `determine_modifications` function call, using an
# empty list when the code does not touch the file system.
_SYSTEM_INSTRUCTIONS = (
    "The user will input some code and you will need to determine if the code makes any changes to the file system. \n"
    "With changes it means creating new files or modifying exsisting ones.\n"
    "Answer with a function call `determine_modifications` and list them inside.\n"
    "If the code does not make any changes to the file system, still answer with the function call but return an empty list.\n"
)

# Chat prompt: the fixed system instructions followed by a human message that
# carries the code under inspection via the `{code}` template variable.
prompt = ChatPromptTemplate(
    input_variables=["code"],
    messages=[
        SystemMessage(content=_SYSTEM_INSTRUCTIONS),
        HumanMessagePromptTemplate.from_template("{code}"),
    ],
)
# JSON-schema description of the arguments the model must supply when it calls
# `determine_modifications`: a single required array of filename strings.
_MODIFICATIONS_PARAMETERS = {
    "type": "object",
    "properties": {
        "modifications": {
            "type": "array",
            "items": {"type": "string"},
            "description": "The filenames that are modified by the code.",
        },
    },
    "required": ["modifications"],
}

# Function specifications handed to the chat model together with the prompt.
# Only one function is declared, so every answer is expected to arrive as a
# `determine_modifications` call.
functions = [
    {
        "name": "determine_modifications",
        "description": (
            "Based on code of the user determine if the code makes any changes to the file system. \n"
            "With changes it means creating new files or modifying exsisting ones.\n"
        ),
        "parameters": _MODIFICATIONS_PARAMETERS,
    }
]
async def get_file_modifications(
    code: str,
    llm: BaseLanguageModel,
    retry: int = 2,
) -> List[str] | None:
    """Ask the language model which files the given code creates or modifies.

    The code is rendered into the module-level ``prompt`` and sent together
    with the ``determine_modifications`` function schema. A reply that does
    not contain a usable function call (no call at all, arguments that are
    not valid JSON, or arguments without a ``modifications`` list) uses up
    one attempt and the request is sent again.

    Args:
        code: Source code to inspect.
        llm: Model used for the request; it is invoked through
            ``apredict_messages`` with the ``functions`` keyword.
        retry: Maximum number of requests to send. A value below 1 sends
            no request at all.

    Returns:
        The filenames reported by the model (an empty list when it reports
        none), or ``None`` if no attempt produced a usable answer.

    Raises:
        OutputParserException: If the model returns something other than
            an ``AIMessage``.
    """
    if retry < 1:
        return None

    # The prompt does not change between attempts, so render it once.
    messages = prompt.format_prompt(code=code).to_messages()

    for _ in range(retry):
        message = await llm.apredict_messages(messages, functions=functions)
        if not isinstance(message, AIMessage):
            raise OutputParserException("Expected an AIMessage")

        function_call = message.additional_kwargs.get("function_call")
        if not function_call:
            # The model answered in plain text instead of calling the function.
            continue

        try:
            arguments = json.loads(function_call["arguments"])
        except (JSONDecodeError, KeyError, TypeError):
            # The arguments are generated by the model and are not guaranteed
            # to be present or to be valid JSON; treat that like a missing
            # call and spend another attempt instead of crashing the caller.
            continue

        modifications = (
            arguments.get("modifications") if isinstance(arguments, dict) else None
        )
        if isinstance(modifications, list):
            return modifications
        # Wrong shape (missing key or not a list): fall through and retry.

    return None
async def test():
    """Manual smoke test for ``get_file_modifications``.

    Builds a ``ChatOpenAI`` model and runs the check on two sample snippets,
    printing the result of each call: a matplotlib script that contains no
    explicit file write, and a pandas script that writes ``Iris.csv``.
    Each sample results in at least one call to the model.
    """
    llm = ChatOpenAI(model="gpt-3.5-turbo-0613")  # type: ignore

    # Sample without an explicit file write (the plot is only shown).
    plot_code = (
        "\n"
        "import matplotlib.pyplot as plt\n"
        "x = list(range(1, 11))\n"
        "y = [29, 39, 23, 32, 4, 43, 43, 23, 43, 77]\n"
        "plt.plot(x, y, marker='o')\n"
        "plt.xlabel('Index')\n"
        "plt.ylabel('Value')\n"
        "plt.title('Data Plot')\n"
        "plt.show()\n"
    )

    # Sample that reads an Excel file and writes a CSV file.
    convert_code = "import pandas as pd\n\n# Read the Excel file\ndata = pd.read_excel('Iris.xlsx')\n\n# Convert the data to CSV\ndata.to_csv('Iris.csv', index=False)"

    # Previously the first sample was defined but never used; run both.
    for sample in (plot_code, convert_code):
        modifications = await get_file_modifications(sample, llm)
        print(modifications)
if __name__ == "__main__":
    # Script entry point: load variables from a .env file first, then drive
    # the async smoke test to completion on a fresh event loop.
    from dotenv import load_dotenv

    load_dotenv()

    import asyncio

    asyncio.run(test())