gerglitzen commited on
Commit
08dd874
·
1 Parent(s): 537df3d
Files changed (1) hide show
  1. main.py +184 -0
main.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from langchain.callbacks.base import BaseCallbackHandler
4
+ from collections.abc import Generator
5
+ from queue import Queue, Empty
6
+ from threading import Thread
7
+
8
+
9
+ from dotenv import load_dotenv
10
+
11
+ load_dotenv()
12
+
13
+ from langchain import PromptTemplate
14
+ from langchain.chains import LLMChain
15
+ from langchain.chat_models import ChatOpenAI
16
+ import pinecone
17
+ from langchain.embeddings import OpenAIEmbeddings
18
+
19
+
20
+ OPENAI_API_KEY=os.environ["OPENAI_API_KEY"]
21
+ PINECONE_API_KEY=os.environ["PINECONE_API_KEY"]
22
+ PINECONE_ENV=os.environ["PINECONE_ENV"]
23
+ PINECONE_INDEX=os.environ["PINECONE_INDEX"]
24
+
25
+ class QueueCallback(BaseCallbackHandler):
26
+ """Callback handler for streaming LLM responses to a queue."""
27
+
28
+ def __init__(self, q):
29
+ self.q = q
30
+
31
+ def on_llm_new_token(self, token: str, **kwargs: any) -> None:
32
+ self.q.put(token)
33
+
34
+ def on_llm_end(self, *args, **kwargs: any) -> None:
35
+ return self.q.empty()
36
+
37
+ # TOOL
38
+ #####################################################################
39
+ llm = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0)
40
+
41
+ template = """
42
+ You are an expert research assistant. You can access information about articles via your tool.
43
+ Use information ONLY from this tool. Do not invent or add any more knowladge, be strict for the articles.
44
+ {instuction}
45
+
46
+ User: {user}
47
+ --------
48
+ {content}
49
+ """
50
+
51
+ prompt = PromptTemplate(
52
+ input_variables=["instuction", "user", "content"],
53
+ template=template,
54
+ )
55
+
56
+ chain = LLMChain(llm=llm, prompt=prompt, callbacks=[QueueCallback])
57
+
58
+ pinecone.init(
59
+ api_key=PINECONE_API_KEY,
60
+ environment=PINECONE_ENV
61
+ )
62
+
63
+ index = pinecone.Index(PINECONE_INDEX)
64
+ embedder = OpenAIEmbeddings()
65
+
66
+
67
+ class PineconeSearch:
68
+ docsearch
69
+ topk
70
+
71
+ def __init__(
72
+ namespace,
73
+ topk
74
+ ):
75
+ self.docsearch = Pinecone.from_existing_index(PINECONE_INDEX, embedder, namespace=namespace)
76
+ self.topk=topk
77
+
78
+ def __call__(query):
79
+ response = self.docsearch.similarity_search(query=query, k=self.topk)
80
+ context = ""
81
+ for doc in docs:
82
+ context += f"Coontent:\n{doc.page_content}\n"
83
+ context += f"Source: {doc.metadta['url']}\n"
84
+ contect += "----"
85
+ return context
86
+
87
+
88
+
89
+ def query_tool(category, pinecone_topk, query):
90
+ data = {
91
+ "1_D3_receptor": "demo-richter-target-400-30-1",
92
+ "2_dopamine": "demo-richter-target-400-30-2",
93
+ "3_mitochondrial": "demo-richter-target-400-30-3"
94
+ }
95
+
96
+ pinecone_namespace = data[category]
97
+
98
+ search_tool = PineconeSearch(
99
+ namespace=pinecone_namespace,
100
+ topk=pinecone_topk,
101
+ )
102
+
103
+ return search_tool(query)
104
+
105
+
106
+
107
+ def print_token_and_price(response):
108
+ inp = sum(response["token_usage"]["prompt_tokens"])
109
+ out = sum( response["token_usage"]["completion_tokens"])
110
+ print(f"Token usage: {inp+out}")
111
+ price = inp/1000*0.01 + out/1000*0.03
112
+ print(f"Total price: {price*370:.2f} Ft")
113
+ print("===================================")
114
+
115
+
116
+
117
+ def stream(input_text, history, user_prompt, topic, topk) -> Generator:
118
+ # Create a Queue
119
+ q = Queue()
120
+ job_done = object()
121
+
122
+ # Create a funciton to call - this will run in a thread
123
+ def task():
124
+ tool_resp = query_tool(topic, topk, input_text)
125
+
126
+ response = chain({"instuction": user_prompt, "user": input_text, "content": tool_resp})
127
+
128
+ #print_token_and_price(response=response)
129
+ q.put(job_done)
130
+
131
+ # Create a thread and start the function
132
+ t = Thread(target=task)
133
+ t.start()
134
+
135
+ content = ""
136
+
137
+ # Get each new token from the queue and yield for our generator
138
+ counter = 0
139
+ while True:
140
+ try:
141
+ next_token = q.get(True, timeout=1)
142
+ if next_token is job_done:
143
+ break
144
+ content += next_token
145
+ counter += 1
146
+ if counter == 20:
147
+ content += "\n"
148
+ counter = 0
149
+ if "\n" in next_token:
150
+ counter = 0
151
+ yield next_token, content
152
+ except Empty:
153
+ continue
154
+
155
+ def ask_llm(message, history, prompt, topic, topk):
156
+ for next_token, content in stream(message, history, prompt, topic, topk):
157
+ yield(content)
158
+
159
+
160
+ agent_prompt_textbox = gr.Textbox(
161
+ label = "Set the behaviour of the agent",
162
+ lines = 2,
163
+ value = "Make your brief answer in bullet points."
164
+ )
165
+ namespace_drobdown = gr.Dropdown(
166
+ ["1_D3_receptor", "2_dopamine", "3_mitochondrial"],
167
+ label="Choose a topic",
168
+ value="1_D3_receptor"
169
+ )
170
+ topk_slider = gr.Slider(
171
+ minimum=10,
172
+ maximum=350,
173
+ value=70,
174
+ step=10
175
+ )
176
+
177
+
178
+ additional_inputs = [agent_prompt_textbox, namespace_drobdown, topk_slider]
179
+
180
+ chatInterface = gr.ChatInterface(
181
+ fn=ask_llm,
182
+ additional_inputs=additional_inputs,
183
+ additional_inputs_accordion_name="Agent parameters"
184
+ ).queue().launch()