amagastya committed (verified) · Commit c153af6 · 1 Parent(s): dbda290

Create app.py

Files changed (1)
  1. app/app.py +233 -0
app/app.py ADDED
@@ -0,0 +1,233 @@
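"""SPARK: Smart Prompt Assistant and Resource Knowledgebase.

A Chainlit chat app for prompt engineering. Each user message is turned into a
search query by Cohere Command A, matched against a Pinecone index (OpenAI
embeddings), re-ranked with Cohere Rerank, and answered with a streamed,
document-grounded Cohere chat response.
"""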
import json
import os
import uuid

import chainlit as cl
import cohere
import yaml
from chainlit import on_chat_start
from langchain_openai import OpenAIEmbeddings
from langchain_pinecone import PineconeVectorStore
from openai import AsyncOpenAI
from pinecone import Pinecone

# Set up API clients from environment variables
client = AsyncOpenAI(api_key=os.environ['OPENAI_API_KEY'])
co = cohere.ClientV2(os.environ['COHERE_API_KEY'])

# Initialize embeddings & vector stores
# embeddings = CohereEmbeddings(cohere_api_key=os.environ['COHERE_API_KEY'], model="embed-english-light-v3.0")
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")

pc = Pinecone(
    api_key=os.environ['PINECONE_API_KEY']
)

learn_index = pc.Index('sparklearn')
prompt_index = pc.Index('spark-prompts')

learnsearch = PineconeVectorStore(index=learn_index, embedding=embeddings)
promptsearch = PineconeVectorStore(index=prompt_index, embedding=embeddings)

learn_retriever = learnsearch.as_retriever(search_kwargs={"k": 8})
prompt_retriever = promptsearch.as_retriever(search_kwargs={"k": 8})

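# Chat profiles shown in the Chainlit UI; the selected profile determines which
# Pinecone index is queried in the message handler below.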
@cl.set_chat_profiles
async def chat_profile():
    return [
        cl.ChatProfile(
            name="Learn Mode",
            markdown_description="Use this mode to learn about prompt engineering.",
            icon="https://www.shutterstock.com/image-vector/brain-emoji-vector-isolated-faces-600nw-2344535053.jpg",
        ),
        cl.ChatProfile(
            name="Prompt Mode",
            markdown_description="Use this mode to query the prompt database.",
            icon="https://e7.pngegg.com/pngimages/296/768/png-clipart-emoji-memorandum-computer-icons-text-messaging-writing-writing-pencil-emoticon.png",
        ),
    ]


@on_chat_start
async def init():
    # Assign a unique id to this conversation and keep it in the user session.
    conversation_id = str(uuid.uuid4())
    cl.user_session.set("id", conversation_id)

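# Main message handler: generate a search query, retrieve from the mode-specific
# index, re-rank the hits with Cohere, then stream a grounded answer to the user.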
# @traceable(run_type="chain")
@cl.on_message
async def main(message: cl.Message):
    task_list = cl.TaskList()
    task_list.status = "Running..."

    mode = cl.user_session.get("chat_profile")

    # Create a task and put it in the running state
    task1 = cl.Task(title="Generating Search Query", status=cl.TaskStatus.RUNNING)
    await task_list.add_task(task1)
    await task_list.send()

    # Add 'running' loader in UI
    msg = cl.Message(content="")
    await msg.send()
    await cl.sleep(0.1)
    # Call Cohere chat in query-generation mode
    try:
        instructions = (
            "Context: You are part of a Retrieval Augmented Generation (RAG) Conversational QA system. "
            "You are the search query generator. Generate a single search query that accurately reflects the user's intent. "
            "The output should simply be a search query, without any additional information or lists."
        )

        # Generate search queries
        search_queries = []

        # Define the query generation tool
        # query_gen_tool = [
        #     {
        #         "type": "function",
        #         "function": {
        #             "name": "internet_search",
        #             "description": "Returns a list of relevant document snippets for a textual query retrieved from the internet",
        #             "parameters": {
        #                 "type": "object",
        #                 "properties": {
        #                     "queries": {
        #                         "type": "array",
        #                         "items": {"type": "string"},
        #                         "description": "a list of queries to search the internet with.",
        #                     }
        #                 },
        #                 "required": ["queries"],
        #             },
        #         },
        #     }
        # ]

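        # Ask Command A for a single standalone search query based on the user's message.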
        res = co.chat(
            model="command-a-03-2025",
            messages=[
                {"role": "system", "content": instructions},
                {"role": "user", "content": message.content},
            ],
            # tools=query_gen_tool
        )
        print("search query", res)

        search_query = res.message.content[0].text if res.message.content else message.content

        msg_id = await msg.send()
        await task_list.add_task(cl.Task(title=f"Generated Search Query: {search_query}", status=cl.TaskStatus.DONE))

        if res.message.tool_calls:
            for tc in res.message.tool_calls:
                queries = json.loads(tc.function.arguments)["queries"]
                search_queries.extend(queries)
            print(search_queries)
    except Exception as e:
        print(f"Error generating search query: {e}")
        search_query = message.content
        msg_id = None  # ensure msg_id is defined for the sources message below
    task1.status = cl.TaskStatus.DONE
    await task_list.send()

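    # Retrieve candidate documents from the Pinecone index for the selected mode.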
    task2 = cl.Task(title="Retrieving Contexts", status=cl.TaskStatus.RUNNING)
    await task_list.add_task(task2)
    await task_list.send()

    # Set retriever based on mode
    if mode == "Learn Mode":
        retriever = learn_retriever
    else:  # Prompt Mode
        retriever = prompt_retriever

    retrieved = retriever.invoke(search_query)
    task2.status = cl.TaskStatus.DONE
    await task_list.send()

    # print('retrieved', retrieved)

    urls = list({d.metadata['source'] for d in retrieved})
    if mode == "Learn Mode":
        docs = [{"Title": d.metadata['title'], "Content": d.page_content} for d in retrieved]
    else:
        docs = [{"Content": d.page_content} for d in retrieved]

    yaml_docs = [yaml.dump(doc, sort_keys=False) for doc in docs]

    task3 = cl.Task(title="Re-Ranking Results", status=cl.TaskStatus.RUNNING)
    await task_list.add_task(task3)
    await task_list.send()

    # Rerank the top results
    reranked = co.rerank(model="rerank-v3.5", query=search_query, documents=yaml_docs, top_n=5)

    reranked_docs = [
        {
            "data": {
                "title": docs[result.index]["Title"] if mode == "Learn Mode" else None,
                "snippet": docs[result.index]["Content"],
            }
        }
        for result in reranked.results
    ]
    # print("Reranked", reranked_docs)
    task3.status = cl.TaskStatus.DONE
    await task_list.send()

    # Generate the final response stream with Cohere chat
    task4 = cl.Task(title="Generating Response", status=cl.TaskStatus.RUNNING)
    await task_list.add_task(task4)
    await task_list.send()
    try:
        # Define the messages list with the preamble and user message
        messages = [
            {"role": "system", "content": (
                "You are SPARK, a Prompt Assistant created by Conversational AI Expert - Amogh Agastya (https://amagastya.com). "
                "SPARK stands for Smart Prompt Assistant and Resource Knowledgebase. SPARK exudes a friendly and knowledgeable persona, "
                "designed to be a reliable and trustworthy guide in the world of prompt engineering. "
                "There are two modes: 'Learn Mode' for generating informative responses and 'Prompt Mode' for crafting prompts. "
                "In 'Prompt Mode', SPARK helps generate prompts for users based on their queries. It provides relevant information and resources to assist them in crafting effective prompts. "
                "Additionally, SPARK in prompt mode can chat with the user to clarify and craft the best prompt for their objective. You can also provide reasoning behind the crafted prompt. "
                f"The user is currently in {mode}."
            )},
            {"role": "user", "content": message.content}
        ]

        stream = co.chat_stream(
            model="command-a-03-2025",
            messages=messages,
            documents=reranked_docs,
        )

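        # Consume the stream: forward text deltas to the UI message and collect citation events.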
        response_text = ""
        citations = []
        for chunk in stream:
            if chunk:
                if chunk.type == "content-delta":
                    response_text += chunk.delta.message.content.text
                    # print(chunk.delta.message.content.text, end="")
                    await msg.stream_token(chunk.delta.message.content.text)
                elif chunk.type == "citation-start":
                    citations.append(chunk.delta.message.citations)

        task4.status = cl.TaskStatus.DONE
        await task_list.send()

    except Exception as e:
        print(f"Error generating response: {e}")

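    # Outside Prompt Mode, send a follow-up message listing the sources
    # (Learn Mode shows the source URLs, other modes show the re-ranked snippets).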
    if mode != "Prompt Mode":  # Only display sources if not in prompt mode
        if mode == "Learn Mode":
            sources = "\n".join([f"- {url}" for url in urls])
        else:
            sources = "\n\n".join([doc['data']['snippet'] for doc in reranked_docs])

        await cl.Message(content=f"*Sources*:\n\n{sources}", parent_id=msg_id).send()

    task4.status = cl.TaskStatus.DONE
    await task_list.send()

    task_list.status = "Completed Successfully"
    await task_list.send()