Shroominic commited on
Commit
1d63754
·
1 Parent(s): 2ea8b71

format final msg prompt

Browse files
codeinterpreterapi/chains/remove_download_link.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.base_language import BaseLanguageModel
2
+ from langchain.chat_models.openai import ChatOpenAI
3
+ from langchain.prompts.chat import (
4
+ ChatPromptTemplate,
5
+ HumanMessagePromptTemplate,
6
+ )
7
+ from langchain.schema import (
8
+ AIMessage,
9
+ OutputParserException,
10
+ SystemMessage,
11
+ HumanMessage
12
+ )
13
+
14
+
15
+ prompt = ChatPromptTemplate(
16
+ input_variables=["input_response"],
17
+ messages=[
18
+ SystemMessage(content=
19
+ "The user will send you a response and you need to remove the download link from it.\n"
20
+ "Reformat the remaining message so no whitespace or half sentences are still there.\n"
21
+ "If the response does not contain a download link, return the response as is.\n"
22
+ ),
23
+ HumanMessage(content="The dataset has been successfully converted to CSV format. You can download the converted file [here](sandbox:/Iris.csv)."),
24
+ AIMessage(content="The dataset has been successfully converted to CSV format."),
25
+ HumanMessagePromptTemplate.from_template("{input_response}")
26
+ ]
27
+ )
28
+
29
+
30
+ async def remove_download_link(
31
+ input_response: str,
32
+ llm: BaseLanguageModel,
33
+ ) -> str:
34
+ messages = prompt.format_prompt(input_response=input_response).to_messages()
35
+ message = await llm.apredict_messages(messages)
36
+
37
+ if not isinstance(message, AIMessage):
38
+ raise OutputParserException("Expected an AIMessage")
39
+
40
+ return message.content
41
+
42
+
43
+ async def test():
44
+ llm = ChatOpenAI(model="gpt-3.5-turbo-0613") # type: ignore
45
+
46
+ example = "I have created the plot to your dataset.\n\nLink to the file [here](sandbox:/plot.png)."
47
+
48
+ modifications = await remove_download_link(example, llm)
49
+
50
+ print(modifications)
51
+
52
+
53
+ if __name__ == "__main__":
54
+ import asyncio
55
+ import dotenv
56
+ dotenv.load_dotenv()
57
+
58
+ asyncio.run(test())