from modal import Stub, Image, Secret, asgi_app, method
from urllib.request import urlretrieve
from fastapi import FastAPI
from typing import List, Dict

# Container image with the third-party runtime dependencies.
image = Image.debian_slim("3.11").pip_install(
    "cohere",
    "gradio==3.50.2",
    "pinecone-client",
)

stub = Stub("secsplorer", image=image)
web_app = FastAPI()


@stub.function(
    secrets=[Secret.from_name("cohere-api-key"), Secret.from_name("pinecone-api-key")]
)
@asgi_app()
def fastapi_app():
    """Build and serve the retrieval-augmented chat app.

    Connects to Cohere (embed / rerank / chat) and Pinecone (vector search),
    then mounts a streaming Gradio ChatInterface on the FastAPI app at "/".

    Returns:
        The FastAPI app with the Gradio UI mounted.
    """
    import cohere
    import pinecone
    import os
    import uuid

    import gradio as gr
    from gradio.routes import mount_gradio_app

    print("Connecting to cohere client")
    co = cohere.Client(os.environ["COHERE_API_KEY"])
    print("Done")

    # NOTE(review): environment and index name are hard-coded — confirm they
    # match the deployed Pinecone project.
    pinecone.init(api_key=os.environ["PINECONE_API_KEY"], environment="us-west1-gcp")
    index = pinecone.Index(index_name="td-sec-embeddings")

    def retrieve(
        index: pinecone.Index, query: str, co: cohere.Client
    ) -> List[Dict[str, str]]:
        """Retrieve and rerank documents for a single query.

        Embeds the query, pulls the top-10 nearest matches from Pinecone,
        then reranks them with Cohere and keeps the top 3.

        Parameters:
            index (pinecone.Index): Vector index to search.
            query (str): The query to retrieve documents for.
            co (cohere.Client): Client used for embedding and reranking.

        Returns:
            List[Dict[str, str]]: Metadata dicts of the top reranked matches.
        """
        print(f"Calling retrieve for '{query}'")
        print("Embedding the query")
        query_emb = co.embed(
            texts=[query], model="embed-english-v3.0", input_type="search_query"
        ).embeddings
        print("Querying pinecone")
        res = index.query(query_emb, top_k=10, include_metadata=True)
        print("Preparing to rerank")
        docs_to_rerank = [match["metadata"] for match in res["matches"]]
        rerank_results = co.rerank(
            query=query,
            documents=docs_to_rerank,
            top_n=3,
            model="rerank-english-v2.0",
        )
        # Map each reranked hit back to its source metadata dict.
        # (Original initialized docs_retrieved twice; collapsed to one pass.)
        docs_retrieved = [docs_to_rerank[hit.index] for hit in rerank_results]
        print("Returning retrieved docs")
        return docs_retrieved

    class Chatbot:
        """Stateful wrapper around Cohere chat with optional retrieval."""

        def __init__(self, co: cohere.Client, index: pinecone.Index):
            self.index = index
            # Fresh id so Cohere keeps the chat history server-side.
            self.conversation_id = str(uuid.uuid4())
            self.co = co

        def generate_response(self, message: str):
            """Stream chat events in response to the user's message.

            Parameters:
                message (str): The user's message.

            Yields:
                Event: Streaming response events from the Cohere chat API.
            """
            # First ask Cohere whether the message needs retrieval at all.
            response = self.co.chat(message=message, search_queries_only=True)

            if response.search_queries:
                # Grounded path: retrieve documents, then answer with them.
                print("Retrieving information")
                documents = self.retrieve_docs(response)
                response = self.co.chat(
                    message=message,
                    documents=documents,
                    conversation_id=self.conversation_id,
                    stream=True,
                )
                for event in response:
                    yield event
            else:
                # No search query: respond directly from the conversation.
                response = self.co.chat(
                    message=message, conversation_id=self.conversation_id, stream=True
                )
                for event in response:
                    yield event

        def retrieve_docs(self, response) -> List[Dict[str, str]]:
            """Retrieve documents for every search query in the response.

            Parameters:
                response: Cohere chat response containing ``search_queries``.

            Returns:
                List[Dict[str, str]]: Metadata dicts for all queries combined.
            """
            queries = [search_query["text"] for search_query in response.search_queries]
            retrieved_docs = []
            for query in queries:
                retrieved_docs.extend(retrieve(self.index, query, self.co))
            return retrieved_docs

    chatbot = Chatbot(co, index)

    def chat_function(message, history):
        """Gradio callback: stream the growing reply, appending citations."""
        flag = False
        reply = ""
        for event in chatbot.generate_response(message):
            if event.event_type == "text-generation":
                reply += str(event.text)
                yield reply
            # Citation events arrive separately, after the generated text.
            if event.event_type == "citation-generation":
                if not flag:
                    # Emit the citations header only once per reply.
                    reply += "\n\nCITATIONS:\n\n"
                    yield reply
                    flag = True
                reply += str(event.citations) + "\n"
                yield reply

    interface = gr.ChatInterface(chat_function).queue()
    print("All ready!")
    return mount_gradio_app(app=web_app, blocks=interface, path="/")