"""Gradio chat UI backed by Cohere chat, with Pinecone retrieval and Cohere rerank."""

import os
import uuid
from typing import Dict, List

import cohere
import gradio as gr
import pinecone
from dotenv import load_dotenv

# Populate os.environ from a local .env file before the API clients read it.
load_dotenv()

co = cohere.Client(os.environ["COHERE_API_KEY"])
pc = pinecone.Pinecone(api_key=os.environ["PINECONE_API_KEY"])
index = pc.Index("td-sec-embeddings")


def retrieve(index: pinecone.Index, query: str) -> List[Dict[str, str]]:
    """
    Retrieves documents based on the given query.

    The query is embedded with Cohere, the 100 nearest matches are fetched
    from the Pinecone index, and those matches are reranked by Cohere so
    that only the top 5 are returned.

    Parameters:
    index (pinecone.Index): The Pinecone index to search.
    query (str): The query to retrieve documents for.

    Returns:
    List[Dict[str, str]]: The metadata dictionaries of the retrieved
        documents, in the order returned by the reranker. The UI reads a
        'url' key from each; 'title' and 'snippet' keys are presumably also
        present (depends on how the index was populated -- verify).
        Empty when the index returns no matches.
    """
    embeddings = co.embed(
        texts=[query],
        model="embed-english-v3.0",
        input_type="search_query",
    ).embeddings

    # `embeddings` holds one vector per input text; a single query was sent,
    # so pass that one flat vector to the index rather than the outer list.
    res = index.query(vector=embeddings[0], top_k=100, include_metadata=True)
    docs_to_rerank = [match["metadata"] for match in res["matches"]]

    # Nothing to rerank: skip the API call instead of sending an empty list.
    if not docs_to_rerank:
        return []

    rerank_results = co.rerank(
        query=query,
        documents=docs_to_rerank,
        top_n=5,
        model="rerank-english-v2.0",
    )

    # Each hit refers back to its position in the list that was submitted.
    return [docs_to_rerank[hit.index] for hit in rerank_results.results]


class Chatbot:
    """One chat conversation: a Cohere client, a Pinecone index, and a conversation id."""

    def __init__(self, co: cohere.Client, index: pinecone.Index):
        self.index = index
        # A fresh id per instance; it is sent with every chat call below.
        self.conversation_id = str(uuid.uuid4())
        self.co = co
        # Maps "doc_<i>" -> document for the most recent retrieval; None until
        # a message has triggered a retrieval.
        self.docs = None

    def send_initial_instructions(self):
        """
        Sends the opening instructions and asks for a welcome message.

        Returns:
        The streamed chat response (an iterable of events).
        """
        response = self.co.chat_stream(
            message=(
                "You are an expert in TD Bank's annual reports and have access"
                " to the 2023 and 2022 annual report. Respond with a polite"
                " welcome message."
            ),
            conversation_id=self.conversation_id,
        )
        return response

    def generate_response(self, message: str):
        """
        Generates a response to the user's message.

        First asks the model only for search queries. If it produces any,
        documents are retrieved for them, remembered in ``self.docs`` and
        passed along with the message; otherwise the message is answered
        directly.

        Parameters:
        message (str): The user's message.

        Yields:
        Event: A response event generated by the chatbot.
        """
        # Generate search queries (if any)
        response = self.co.chat(
            message=message,
            search_queries_only=True,
            conversation_id=self.conversation_id,
        )

        if response.search_queries:
            # If there are search queries, retrieve documents and respond
            print("Retrieving information...")
            documents = self.retrieve_docs(response)
            # Keyed as "doc_<position>" so chat_function can look documents up
            # by the ids found in citations (assumes citation document_ids
            # follow this positional scheme -- confirm against the Cohere API).
            self.docs = {f"doc_{i}": document for i, document in enumerate(documents)}
            stream = self.co.chat_stream(
                message=message,
                documents=documents,
                conversation_id=self.conversation_id,
            )
        else:
            # If there is no search query, directly respond
            stream = self.co.chat_stream(
                message=message,
                conversation_id=self.conversation_id,
            )

        yield from stream

    def retrieve_docs(self, response) -> List[Dict[str, str]]:
        """
        Retrieves documents based on the search queries in the response.

        Parameters:
        response: The response object containing search queries.

        Returns:
        List[Dict[str, str]]: A list of dictionaries representing the
            retrieved documents, concatenated across all queries in query
            order (documents matching several queries may appear more than
            once).
        """
        # Get the query(s)
        queries = [search_query.text for search_query in response.search_queries]

        # Retrieve documents for each query
        retrieved_docs = []
        for query in queries:
            retrieved_docs.extend(retrieve(self.index, query))

        return retrieved_docs


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")
    # Per-session Chatbot instance; None until chat_function creates it.
    cohere_chatbot_var = gr.State()

    def user(user_message, history):
        """Clear the textbox and append the user's message as a pending turn."""
        # The Clear button sets the chat value to None, so fall back to an
        # empty history rather than failing on `None + list`.
        return "", (history or []) + [[user_message, None]]

    def chat_function(history, cohere_chatbot):
        """
        Stream the bot's reply into the chat history.

        On the first call (no Chatbot in state yet) a Chatbot is created and
        its welcome message is streamed. On later calls the last user message
        in ``history`` is answered, followed by a list of the URLs of the
        documents that were cited.

        Yields:
        (history, cohere_chatbot) pairs, one per UI update.
        """
        if cohere_chatbot is None:
            cohere_chatbot = Chatbot(co, index)
            response = cohere_chatbot.send_initial_instructions()
            history = [[None, ""]]
            for event in response:
                if event.event_type == "text-generation":
                    history[0][1] += str(event.text)
                    yield history, cohere_chatbot
            return

        message = history[-1][0]
        history[-1][1] = ""
        documents_used = set()
        header_shown = False

        for event in cohere_chatbot.generate_response(message):
            if event.event_type == "text-generation":
                history[-1][1] += str(event.text)
                yield history, cohere_chatbot

            # Citations
            if event.event_type == "citation-generation":
                if not header_shown:
                    history[-1][1] += "\n\n**DOCUMENTS CONSULTED:**\n\n"
                    yield history, cohere_chatbot
                    header_shown = True
                for citation in event.citations:
                    documents_used.update(citation.document_ids)

        # List the URLs once, after the stream has finished, so a URL cited in
        # several events is not appended repeatedly. Unknown ids and documents
        # without a 'url' are skipped instead of raising.
        known_docs = cohere_chatbot.docs or {}
        urls_used = {
            known_docs[doc_id]["url"]
            for doc_id in documents_used
            if "url" in known_docs.get(doc_id, {})
        }
        for url in sorted(urls_used):
            history[-1][1] += f"* {url}\n"
            yield history, cohere_chatbot

    # Make sure we run the thing once to initialize!
    demo.load(
        chat_function, [chatbot, cohere_chatbot_var], [chatbot, cohere_chatbot_var]
    )

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        chat_function, [chatbot, cohere_chatbot_var], [chatbot, cohere_chatbot_var]
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue()
demo.launch()