File size: 3,103 Bytes
2ea8b71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import json
from json import JSONDecodeError
from typing import List

from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.schema import (
    AIMessage,
    OutputParserException,
    SystemMessage,
)


# Chat prompt asking the model to report file-system modifications made by a
# code snippet, answered via the `determine_modifications` function call.
# (Fixed typo: "exsisting" -> "existing" in the instruction text.)
prompt = ChatPromptTemplate(
    input_variables=["code"],
    messages=[
        SystemMessage(content=
            "The user will input some code and you will need to determine if the code makes any changes to the file system. \n"
            "With changes it means creating new files or modifying existing ones.\n"
            "Answer with a function call `determine_modifications` and list them inside.\n"
            "If the code does not make any changes to the file system, still answer with the function call but return an empty list.\n",
        ),
        HumanMessagePromptTemplate.from_template("{code}")
    ]
)

# OpenAI function-calling schema passed to the model: one function,
# `determine_modifications`, returning the list of filenames the user's code
# creates or modifies. (Fixed typo: "exsisting" -> "existing".)
functions = [
    {
        "name": "determine_modifications",
        "description":
            "Based on code of the user determine if the code makes any changes to the file system. \n"
            "With changes it means creating new files or modifying existing ones.\n",
        "parameters": {
            "type": "object",
            "properties": {
                "modifications": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "The filenames that are modified by the code.",
                },
            },
            # The model must always return the key, even as an empty list.
            "required": ["modifications"],
        },
    }
]


async def get_file_modifications(
    code: str,
    llm: BaseLanguageModel,
    retry: int = 2,
) -> List[str] | None:
    """Ask the LLM which files *code* creates or modifies.

    Sends the prompt plus the `determine_modifications` function schema to the
    model and parses the function-call arguments from the reply.

    Args:
        code: The source code snippet to analyze.
        llm: The language model to query (must support function calling).
        retry: Remaining attempts; each failed parse or missing function
            call consumes one.

    Returns:
        The list of modified filenames, or ``None`` once retries are exhausted.

    Raises:
        OutputParserException: If the model reply is not an ``AIMessage``.
    """
    if retry < 1:
        return None

    messages = prompt.format_prompt(code=code).to_messages()
    message = await llm.apredict_messages(messages, functions=functions)

    if not isinstance(message, AIMessage):
        raise OutputParserException("Expected an AIMessage")

    function_call = message.additional_kwargs.get("function_call")
    if function_call is None:
        # Model answered in plain text instead of calling the function — retry.
        return await get_file_modifications(code, llm, retry=retry - 1)

    try:
        arguments = json.loads(function_call["arguments"])
        return arguments["modifications"]
    except (JSONDecodeError, KeyError):
        # Bug fix: the original let malformed/incomplete arguments crash even
        # though JSONDecodeError was imported; treat it as a failed attempt.
        return await get_file_modifications(code, llm, retry=retry - 1)
    

async def test():
    """Manual smoke test: ask the model which files a snippet modifies."""
    llm = ChatOpenAI(model="gpt-3.5-turbo-0613")  # type: ignore

    # Read-only example: plots data, touches no files.
    plot_snippet = """
    import matplotlib.pyplot as plt

    x = list(range(1, 11))
    y = [29, 39, 23, 32, 4, 43, 43, 23, 43, 77]

    plt.plot(x, y, marker='o')
    plt.xlabel('Index')
    plt.ylabel('Value')
    plt.title('Data Plot')

    plt.show()
    """

    # Writing example: converts an Excel file to CSV, so it creates a file.
    convert_snippet = "import pandas as pd\n\n# Read the Excel file\ndata = pd.read_excel('Iris.xlsx')\n\n# Convert the data to CSV\ndata.to_csv('Iris.csv', index=False)"

    print(await get_file_modifications(convert_snippet, llm))


if __name__ == "__main__":
    import asyncio

    from dotenv import load_dotenv

    # Pull environment variables (e.g. API credentials) from a local .env
    # file before the test hits the model API.
    load_dotenv()
    asyncio.run(test())