Spaces:
Runtime error
Runtime error
| import cohere | |
| import os | |
| import pinecone | |
| import uuid | |
| from typing import List, Dict | |
| from dotenv import load_dotenv | |
# Load variables from a local .env file (if one exists) into the process
# environment before the API keys are read below.
load_dotenv()

# Shared API clients used by the retrieval function and the Chatbot class.
# Indexing os.environ directly means a missing key raises KeyError at import.
co = cohere.Client(os.environ["COHERE_API_KEY"])
pc = pinecone.Pinecone(api_key=os.environ["PINECONE_API_KEY"])

# Handle to the Pinecone index that is queried for candidate documents.
index = pc.Index("td-sec-embeddings")
def retrieve(
    index: pinecone.Index,
    query: str,
    *,
    top_k: int = 100,
    top_n: int = 5,
) -> List[Dict[str, str]]:
    """
    Retrieves documents based on the given query.

    The query is embedded with the module-level Cohere client ``co``, the
    embedding is used to fetch the ``top_k`` nearest matches from ``index``,
    and those matches are reranked by Cohere so that only the ``top_n`` best
    are returned.

    Parameters:
        index (pinecone.Index): The Pinecone index to search.
        query (str): The query to retrieve documents for.
        top_k (int): Number of candidates to fetch from the index before
            reranking. Defaults to 100.
        top_n (int): Number of documents to keep after reranking.
            Defaults to 5.

    Returns:
        List[Dict[str, str]]: The metadata dictionaries of the reranked
        matches, best first. Per the original author's note these are
        expected to carry 'title', 'snippet', and 'url' keys (the schema is
        set by whatever populated the index -- TODO confirm). Empty when the
        index returns no matches.
    """
    # embed() returns one embedding per input text; exactly one text is sent,
    # so take that single vector (a flat list of floats) for the index query.
    query_emb = co.embed(
        texts=[query], model="embed-english-v3.0", input_type="search_query"
    ).embeddings[0]

    res = index.query(vector=query_emb, top_k=top_k, include_metadata=True)
    docs_to_rerank = [match["metadata"] for match in res["matches"]]

    # Nothing to rerank: skip the rerank call instead of sending an empty list.
    if not docs_to_rerank:
        return []

    rerank_results = co.rerank(
        query=query,
        documents=docs_to_rerank,
        top_n=top_n,
        model="rerank-english-v2.0",
    )

    # Each rerank hit points back into docs_to_rerank by position.
    return [docs_to_rerank[hit.index] for hit in rerank_results.results]
class Chatbot:
    """Conversation wrapper around a Cohere client and a Pinecone index.

    Each instance gets its own random ``conversation_id`` that is passed on
    every Cohere chat call, and remembers the documents used for the most
    recent retrieval-backed answer in ``self.docs``.
    """

    def __init__(self, co: cohere.Client, index: pinecone.Index):
        self.co = co
        self.index = index
        # Fresh identifier per instance, sent with every chat request.
        self.conversation_id = str(uuid.uuid4())
        # Maps "doc_<position>" -> document dict; None until the first retrieval.
        self.docs = None

    def send_initial_instructions(self):
        """
        Sends the opening instruction message and asks for a welcome reply.
        Returns:
        The streaming response object produced by ``chat_stream``.
        """
        return self.co.chat_stream(
            message="""You are an expert in TD Bank's annual reports and have access to the 2023 and 2022 annual report. Respond with a polite welcome message.""",
            conversation_id=self.conversation_id,
        )

    def generate_response(self, message: str):
        """
        Streams the chatbot's reply to the user's message.
        Parameters:
        message (str): The user's message.
        Yields:
        Event: Each event from the ``chat_stream`` response, in order.
        Side effects:
        When search queries are produced, ``self.docs`` is replaced with the
        newly retrieved documents keyed as "doc_0", "doc_1", ...
        """
        # First pass: ask only for the search queries (if any).
        query_plan = self.co.chat(
            message=message,
            search_queries_only=True,
            conversation_id=self.conversation_id,
        )

        # No search queries: answer directly, without documents.
        if not query_plan.search_queries:
            direct_stream = self.co.chat_stream(
                message=message,
                conversation_id=self.conversation_id,
            )
            for event in direct_stream:
                yield event
            return

        # Search queries present: retrieve documents and answer with them.
        print("Retrieving information...")
        documents = self.retrieve_docs(query_plan)
        self.docs = {f"doc_{pos}": doc for pos, doc in enumerate(documents)}
        grounded_stream = self.co.chat_stream(
            message=message,
            documents=documents,
            conversation_id=self.conversation_id,
        )
        for event in grounded_stream:
            yield event

    def retrieve_docs(self, response) -> List[Dict[str, str]]:
        """
        Retrieves documents for every search query carried by the response.
        Parameters:
        response: The response object containing search queries.
        Returns:
        List[Dict[str, str]]: The documents retrieved for each query,
        concatenated in query order.
        """
        # Collect the text of every query before doing any retrieval.
        query_texts = [search_query.text for search_query in response.search_queries]

        collected = []
        for text in query_texts:
            collected += retrieve(self.index, text)
        return collected
import gradio as gr

with gr.Blocks() as demo:
    # Widgets: chat transcript, message box, reset button, and a state slot
    # that carries the Chatbot wrapper between calls.
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")
    cohere_chatbot_var = gr.State()

    def user(user_message, history):
        """Append the submitted text to the history as an unanswered turn.

        Returns an empty string (written to the first output, the textbox)
        and the history extended with ``[user_message, None]``.
        """
        pending_turn = [user_message, None]
        return "", history + [pending_turn]

    def chat_function(history, cohere_chatbot):
        """Stream the bot's reply into the last history entry.

        On the first call (``cohere_chatbot`` is None) a Chatbot is created,
        the history is reset to a single bot-only turn, and the welcome
        message is streamed into it. Otherwise the last user message is
        answered and, if citation events arrive, a list of the cited
        documents' URLs is appended after the answer.

        Yields ``(history, cohere_chatbot)`` after every update.
        """
        if cohere_chatbot is None:
            cohere_chatbot = Chatbot(co, index)
            welcome_stream = cohere_chatbot.send_initial_instructions()
            history = [[None, ""]]
            for event in welcome_stream:
                if event.event_type == "text-generation":
                    history[0][1] += str(event.text)
                    yield history, cohere_chatbot
            return

        message = history[-1][0]
        history[-1][1] = ""
        cited_ids = set()
        header_pending = True  # the "documents consulted" header is added once

        for event in cohere_chatbot.generate_response(message):
            kind = event.event_type
            if kind == "text-generation":
                history[-1][1] += str(event.text)
                yield history, cohere_chatbot
            elif kind == "citation-generation":
                if header_pending:
                    history[-1][1] += "\n\n**DOCUMENTS CONSULTED:**\n\n"
                    yield history, cohere_chatbot
                    header_pending = False
                for citation in event.citations:
                    cited_ids |= set(citation.document_ids)

        # After the stream ends, list each cited document's URL once, sorted.
        cited_urls = {cohere_chatbot.docs[doc_id]["url"] for doc_id in cited_ids}
        for url in sorted(cited_urls):
            history[-1][1] += f"* {url}\n"
            yield history, cohere_chatbot

    # Make sure we run the thing once to initialize!
    demo.load(
        fn=chat_function,
        inputs=[chatbot, cohere_chatbot_var],
        outputs=[chatbot, cohere_chatbot_var],
    )

    # On submit: record the user's turn first, then stream the answer.
    msg.submit(
        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(
        fn=chat_function,
        inputs=[chatbot, cohere_chatbot_var],
        outputs=[chatbot, cohere_chatbot_var],
    )

    # "Clear" writes None to the chat transcript component.
    clear.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)

demo.queue()
demo.launch()