solalatus commited on
Commit
4bfaa77
·
1 Parent(s): 371bd01

initial commit

Browse files
Files changed (2) hide show
  1. agent.py +104 -0
  2. app.py +40 -0
agent.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ from dotenv import dotenv_values
3
+ import os
4
+ from tqdm.auto import tqdm
5
+ import pinecone
6
+ from langchain.embeddings import OpenAIEmbeddings
7
+ from pinecone_text.sparse import BM25Encoder
8
+ from langchain.retrievers import PineconeHybridSearchRetriever
9
+ from langchain.chat_models import ChatOpenAI
10
+ from langchain.agents import initialize_agent, Tool
11
+ from langchain.tools.base import BaseTool
12
+ from langchain.agents import AgentType
13
+ from langchain.agents.react.base import DocstoreExplorer
14
+ from langchain import LLMMathChain
15
+ from typing import Union
16
+ from langchain.memory import ConversationBufferWindowMemory
17
+
18
+
19
+ load_dotenv()
20
+ config = dotenv_values(".env")
21
+
22
+
23
+
24
+ class CalculatorTool(BaseTool):
25
+ name = "CalculatorTool"
26
+
27
+ description = """
28
+ Useful for when you need to execute specific math calculations.
29
+ This tool is only for math calculations and nothing else.
30
+ Formulate the input as python code.
31
+ """
32
+
33
+ def _run(self, question: str):
34
+ return exec(question)
35
+
36
+ def _arun(self, value: Union[int, float]):
37
+ raise NotImplementedError("This tool does not support async")
38
+
39
+
40
+
41
+ class QMLAgent():
42
+
43
+ def __init__(self):
44
+
45
+ pinecone.init(api_key=config["PINECONE_API_KEY"], environment=config["PINECONE_REGION"])
46
+
47
+ index = pinecone.Index(config["INDEX_NAME"])
48
+
49
+ embeddings = OpenAIEmbeddings()
50
+
51
+
52
+ bm25_encoder = BM25Encoder()
53
+ bm25_encoder.load(config["BM25_FILENAME"])
54
+
55
+ retriever = PineconeHybridSearchRetriever(
56
+ embeddings=embeddings,
57
+ sparse_encoder=bm25_encoder,
58
+ index=index,
59
+ top_k=config["TOP_K"])
60
+
61
+ llm = ChatOpenAI(model_name=config["CHAT_MODEL"])
62
+
63
+ math_tool = CalculatorTool()
64
+
65
+ tools = [
66
+ Tool(
67
+ name="Search",
68
+ func=retriever.get_relevant_documents,
69
+ description="You have to use this to search for knowledge about quantum computing and quantum machine learning.",
70
+ ),
71
+ Tool.from_function(
72
+ name="Match calculation",
73
+ func=math_tool._run,
74
+ description="""
75
+ Useful for when you need to execute specific math calculations.
76
+ This tool is only for math calculations and nothing else.
77
+ Formulate the input as python code, always use explicit printing for results!.
78
+ """
79
+ #return_direct=False
80
+
81
+ ),
82
+ ]
83
+
84
+ memory = ConversationBufferWindowMemory(k=config["MEMORY_LENGTH"], memory_key="chat_history", return_messages=True)
85
+
86
+
87
+ self.agent_chain = initialize_agent(
88
+ tools,
89
+ llm,
90
+ agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
91
+ verbose=True,
92
+ return_intermediate_steps=False,
93
+ memory=memory,
94
+ handle_parsing_errors="Always provide only code answer in a single block parseable in JSON, nothing more! Delete your other remarks, just output the pure code! Always produce runnable code, no parts left blank!"
95
+ )
96
+
97
+ def run(self, question):
98
+ return self.agent_chain.run(question)
99
+
100
+ if __name__ == '__main__':
101
+ agent = QMLAgent()
102
+ question = "What is eigenvector for the matrix [[1,2,3],[4,5,6],[7,8,9]] raised to the second power?"
103
+
104
+ print(agent.run(question))
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from dotenv import dotenv_values
3
+
4
+ from agent import QMLAgent
5
+
6
+
7
+ config = dotenv_values(".env")
8
+
9
+ agent = QMLAgent()
10
+
11
+
12
+ with gr.Blocks(theme='snehilsanyal/scikit-learn') as demo:
13
+ #with gr.Tab("QA"):
14
+ chatbot = gr.Chatbot(label="QML Class Conversation Agent Demo")
15
+ msg = gr.Textbox()
16
+ clear = gr.Button("Clear")
17
+
18
+ def respond(user_message, chat_history):#, progress=gr.Progress()):
19
+ global agent
20
+ bot_message = agent.run(user_message)
21
+ chat_history.append((user_message, bot_message))
22
+
23
+ return "", chat_history
24
+
25
+ msg.submit(respond, [msg, chatbot], [msg, chatbot])
26
+ clear.click(lambda: None, None, chatbot, queue=False)
27
+
28
+
29
+
30
+ db_login = {
31
+ config["USER_NAME"]: config["USER_PWD"]
32
+ }
33
+
34
+ def _myauth(username, password):
35
+ if db_login.get(username) == password:
36
+ return True
37
+ return False
38
+
39
+ #demo.queue(concurrency_count=10).launch(server_port=7860, server_name='0.0.0.0')
40
+ demo.queue(concurrency_count=10).launch(auth=_myauth, server_port=7860, server_name='0.0.0.0')