# secsplorer / modal_script.py
# Uploaded via huggingface_hub by lagerbaer (commit 99e964c)
from modal import Stub, Image, Secret, asgi_app, method
from urllib.request import urlretrieve
from fastapi import FastAPI
from typing import List, Dict
image = Image.debian_slim("3.11").pip_install(
"cohere",
"gradio==3.50.2",
"pinecone-client",
)
stub = Stub("secsplorer", image=image)
web_app = FastAPI()
@stub.function(
    secrets=[Secret.from_name("cohere-api-key"), Secret.from_name("pinecone-api-key")]
)
@asgi_app()
def fastapi_app():
    """Serve the SEC-filings RAG chatbot as an ASGI (Gradio-on-FastAPI) app.

    Reads COHERE_API_KEY / PINECONE_API_KEY from the mounted Modal secrets,
    builds a retrieve -> rerank -> chat pipeline against the
    "td-sec-embeddings" Pinecone index, and mounts a streaming Gradio
    ChatInterface on the module-level FastAPI app at "/".

    Returns:
        The FastAPI app with the Gradio blocks mounted (via
        gradio.routes.mount_gradio_app).
    """
    # Deferred imports: these packages exist only inside the Modal image.
    import cohere
    import pinecone
    import os
    import uuid
    import gradio as gr
    from gradio.routes import mount_gradio_app

    print("Connecting to cohere client")
    co = cohere.Client(os.environ["COHERE_API_KEY"])
    print("Done")
    pinecone.init(api_key=os.environ["PINECONE_API_KEY"], environment="us-west1-gcp")
    index = pinecone.Index(index_name="td-sec-embeddings")

    def retrieve(
        index: pinecone.Index, query: str, co: cohere.Client
    ) -> List[Dict[str, str]]:
        """Embed *query*, fetch the 10 nearest neighbours from Pinecone, and
        rerank them with Cohere, keeping the best 3.

        Parameters:
            index: Pinecone index holding the document embeddings.
            query (str): The query to retrieve documents for.
            co: Authenticated Cohere client.

        Returns:
            List[Dict[str, str]]: The metadata dicts of the retained
            documents (expected keys: 'title', 'snippet', 'url').
        """
        print(f"Calling retrieve for '{query}'")
        print("Embedding the query")
        query_emb = co.embed(
            texts=[query], model="embed-english-v3.0", input_type="search_query"
        ).embeddings
        print("Querying pinecone")
        # NOTE(review): co.embed returns one embedding per input text, so
        # query_emb is a list containing a single vector; confirm the pinecone
        # client accepts that shape here (otherwise query_emb[0] is needed).
        res = index.query(query_emb, top_k=10, include_metadata=True)
        print("Preparing to rerank")
        docs_to_rerank = [match["metadata"] for match in res["matches"]]
        rerank_results = co.rerank(
            query=query,
            documents=docs_to_rerank,
            top_n=3,
            model="rerank-english-v2.0",
        )
        # Map the reranker's hits back to the original metadata dicts.
        # (The original initialised docs_retrieved twice; the first
        # assignment was dead code and has been dropped.)
        docs_retrieved = [docs_to_rerank[hit.index] for hit in rerank_results]
        print("Returning retrieved docs")
        return docs_retrieved

    class Chatbot:
        """Stateful wrapper around Cohere chat with document grounding."""

        def __init__(self, co: cohere.Client, index: pinecone.Index):
            self.index = index
            # A fresh conversation id lets Cohere thread all turns together.
            self.conversation_id = str(uuid.uuid4())
            self.co = co

        def generate_response(self, message: str):
            """Generate a streamed response to the user's message.

            Parameters:
                message (str): The user's message.

            Yields:
                Event: Streaming chat events from Cohere (text chunks and,
                when grounded, citation events).
            """
            # First ask Cohere only for the search queries it would issue.
            response = self.co.chat(message=message, search_queries_only=True)
            if response.search_queries:
                # Grounded path: retrieve documents, then chat over them.
                print("Retrieving information")
                documents = self.retrieve_docs(response)
                response = self.co.chat(
                    message=message,
                    documents=documents,
                    conversation_id=self.conversation_id,
                    stream=True,
                )
                yield from response
            else:
                # No search queries: respond directly.
                response = self.co.chat(
                    message=message, conversation_id=self.conversation_id, stream=True
                )
                yield from response

        def retrieve_docs(self, response) -> List[Dict[str, str]]:
            """Run `retrieve` for every search query Cohere proposed.

            Parameters:
                response: Cohere chat response carrying `search_queries`.

            Returns:
                List[Dict[str, str]]: Concatenated metadata dicts retrieved
                for all queries (may contain duplicates across queries).
            """
            queries = [search_query["text"] for search_query in response.search_queries]
            retrieved_docs = []
            for query in queries:
                retrieved_docs.extend(retrieve(self.index, query, self.co))
            return retrieved_docs

    chatbot = Chatbot(co, index)

    def chat_function(message, history):
        """Gradio ChatInterface callback: yield the growing reply string.

        `history` is supplied by Gradio but intentionally unused — Cohere
        tracks context server-side via the conversation_id.
        """
        citations_started = False
        reply = ""
        for event in chatbot.generate_response(message):
            if event.event_type == "text-generation":
                reply += str(event.text)
                yield reply
            # Citation events arrive after the text; emit the header once.
            if event.event_type == "citation-generation":
                if not citations_started:
                    reply += "\n\nCITATIONS:\n\n"
                    yield reply
                    citations_started = True
                reply += str(event.citations) + "\n"
                yield reply

    interface = gr.ChatInterface(chat_function).queue()
    print("All ready!")
    return mount_gradio_app(app=web_app, blocks=interface, path="/")