runsdata committed on
Commit
8181e20
·
1 Parent(s): fb9f007

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import openai
4
+ import gradio as gr
5
+ from langchain.embeddings.openai import OpenAIEmbeddings
6
+ from langchain.vectorstores import Chroma
7
+ from langchain.chat_models import ChatOpenAI
8
+ from langchain.schema import AIMessage, HumanMessage, SystemMessage
9
+
# --- One-time model and retrieval setup (runs at import time) ---

# OpenAI embeddings model used to embed queries against the vector store.
embeddings = OpenAIEmbeddings()

# Load the persisted Chroma vector database from disk.
db_directory = "chroma_db"
db = Chroma(persist_directory=db_directory, embedding_function=embeddings)

# Retriever returning the top-3 most similar documents for a query.
retriever = db.as_retriever(search_type='similarity', search_kwargs={"k": 3})

# System prompt lives in a separate file; read it once at startup.
# Encoding is pinned so non-ASCII prompt text decodes the same on every platform.
with open('system_prompt.txt', 'r', encoding='utf-8') as file:
    ORIG_SYSTEM_MESSAGE_PROMPT = file.read()

openai.api_key = os.getenv("OPENAI_API_KEY")

# Deterministic (temperature=0) GPT-4 chat model via LangChain.
chat = ChatOpenAI(model_name="gpt-4", temperature=0)
26
+
27
# Here is the langchain
def predict(history, input):
    """Run one retrieval-augmented chat turn against the GPT-4 model.

    Args:
        history: list of (human, ai) message pairs from the Gradio chatbot.
            The last pair may be the in-progress placeholder appended by
            `user()` whose ai half is still "".
        input: the user's newest message. (Parameter name shadows the
            builtin `input`; kept unchanged for caller compatibility.)

    Returns:
        The conversation as a list of (human, ai) tuples, ending with
        (input, <new model response>).
    """
    # Retrieve the top-k stories relevant to the new user message.
    context = retriever.get_relevant_documents(input)
    print(context)  # For debugging

    # Build the LangChain message list: system prompt, completed prior
    # turns, the new user message, then the retrieved context as a hint.
    messages = [SystemMessage(content=f"{ORIG_SYSTEM_MESSAGE_PROMPT}")]
    completed_pairs = []
    for human, ai in history:
        # Skip the in-progress placeholder pair (empty ai half): previously
        # it was forwarded verbatim, so the model received the new user
        # message twice plus a spurious empty assistant turn.
        if not ai:
            continue
        messages.append(HumanMessage(content=human))
        messages.append(AIMessage(content=ai))
        completed_pairs.append((human, ai))
    messages.append(HumanMessage(content=input))
    messages.append(SystemMessage(content=f"Here are some stories the user may like: {context}"))

    gpt_response = chat(messages)

    # Add the new (question, answer) pair for subsequent interactions.
    completed_pairs.append((input, gpt_response.content))
    return completed_pairs
51
+
52
# Function to handle user message
def user(user_message, chatbot_history):
    """Record the new user turn and clear the input textbox.

    Returns ("", updated_history) where the history gains a new
    [message, ""] entry — the empty string is a placeholder that
    `bot()` fills in as the reply streams.
    """
    updated_history = chatbot_history + [[user_message, ""]]
    return "", updated_history
55
+
56
# Function to handle AI's response
def bot(chatbot_history):
    """Generator that fills in the bot half of the last history entry,
    yielding after each character to simulate streaming in the UI."""
    # user() cleared the textbox, so recover the question from history.
    latest_question = chatbot_history[-1][0]
    # Call the predict function to get the AI's response.
    conversation_pairs = predict(chatbot_history, latest_question)
    answer = conversation_pairs[-1][1]  # newest model response

    partial = ""
    for ch in answer:
        partial += ch
        chatbot_history[-1][1] = partial
        time.sleep(0.05)
        yield chatbot_history
69
+
70
# This is a function to do something with the voted information
def vote(data: gr.LikeData):
    """Handle a like/dislike event on a chatbot message.

    Upvotes are only printed; downvotes are printed and appended to
    logs.txt for later review.
    """
    if data.liked:
        # f-string instead of `+` concatenation: gradio may deliver a
        # non-str value (e.g. dict), which would raise TypeError with `+`.
        print(f"You upvoted this response: {data.value}")
    else:
        print(f"You downvoted this response: {data.value}")
        # Encoding pinned so non-ASCII responses log identically everywhere.
        with open("logs.txt", "a", encoding="utf-8") as text_file:
            print(f"Disliked content: {data.value}", file=text_file)
78
+
79
# The Gradio App interface
with gr.Blocks() as demo:
    gr.Markdown("""<h1><center>Technocomplex Bot</center></h1>""")
    chatbot = gr.Chatbot(label="Technocomplex Bot")
    textbox = gr.Textbox(label="Start chatting here!")
    clear = gr.Button("Clear")

    # On submit: stash the user message first, then stream the bot reply.
    submit_event = textbox.submit(
        user, [textbox, chatbot], [textbox, chatbot], queue=False
    )
    submit_event.then(bot, chatbot, chatbot)

    # Clear resets the chat display; thumbs up/down feed the vote() logger.
    clear.click(lambda: None, None, chatbot, queue=False)
    chatbot.like(vote, None, None)

# Enable queuing
demo.queue()
demo.launch(debug=True, share=True)