finchat / app.py
Monsia's picture
first commit
c4331f2
raw
history blame
3.86 kB
import chainlit as cl
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough
from langchain.vectorstores.chroma import Chroma
from langchain_google_genai import (
GoogleGenerativeAI,
GoogleGenerativeAIEmbeddings,
HarmBlockThreshold,
HarmCategory,
)
import config
from prompts import prompt
# Metadata fields the self-query retriever is allowed to filter on.
# All fields are plain strings: article title, publication date, and source link.
metadata_field_info = [
    AttributeInfo(name=field_name, description=field_description, type="string")
    for field_name, field_description in [
        ("title", "Le titre de l'article"),
        ("date", "Date de publication"),
        ("link", "Source de l'article"),
    ]
]
# One-line corpus description handed to the self-query retriever so the LLM
# knows what kind of documents it is querying.
document_content_description = "Articles sur l'actualité."
# Chat LLM used both for answering and (below) for query construction.
# NOTE(review): dangerous-content filtering is explicitly disabled here —
# confirm this is intentional for this deployment.
model = GoogleGenerativeAI(
    model=config.GOOGLE_CHAT_MODEL,
    google_api_key=config.GOOGLE_API_KEY,
    safety_settings={
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    },
) # type: ignore
# Load vector database that was persisted earlier
# (two names bound to the same embeddings object; only `embedding` is used below)
embedding = embeddings_model = GoogleGenerativeAIEmbeddings(
    model="models/embedding-001", google_api_key=config.GOOGLE_API_KEY
) # type: ignore
# Chroma store persisted at STORAGE_PATH; queries embed with the model above.
vectordb = Chroma(persist_directory=config.STORAGE_PATH, embedding_function=embedding)
# Self-query retriever: lets the LLM translate natural-language constraints
# into metadata filters over the fields declared in metadata_field_info.
# NOTE(review): this retriever is constructed but the chain in on_chat_start
# uses vectordb.as_retriever() instead — confirm which one is intended.
retriever = SelfQueryRetriever.from_llm(
    model,
    vectordb,
    document_content_description,
    metadata_field_info,
)
@cl.on_chat_start
async def on_chat_start():
    """Build the per-session RAG chain and post a welcome message."""

    def format_docs(docs):
        # Join retrieved documents into one context string for the prompt.
        return "\n\n".join(doc.page_content for doc in docs)

    # NOTE(review): this chain uses the plain vector-store retriever; the
    # module-level SelfQueryRetriever is never wired in — confirm which is
    # intended.
    rag_chain = (
        {
            "context": vectordb.as_retriever() | format_docs,
            "question": RunnablePassthrough(),
        }
        | prompt
        | model
        | StrOutputParser()
    )
    # Stash the chain in the user session so on_message can reuse it.
    cl.user_session.set("rag_chain", rag_chain)
    # Plain string literal: the original f-string had no placeholders.
    msg = cl.Message(
        content="Vous pouvez poser vos questions sur les articles de SIKAFINANCE",
    )
    await msg.send()
@cl.on_message
async def on_message(message: cl.Message):
    """Stream the RAG chain's answer for *message*, appending source citations."""
    runnable = cl.user_session.get("rag_chain")  # type: Runnable # type: ignore
    msg = cl.Message(content="")

    class PostMessageHandler(BaseCallbackHandler):
        """
        Callback handler for handling the retriever and LLM processes.
        Used to post the sources of the retrieved documents as a Chainlit element.
        """

        def __init__(self, msg: cl.Message):
            super().__init__()
            self.msg = msg
            self.sources: list[str] = []

        def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
            # Record each retrieved document's content with its source link.
            for d in documents:
                # .get avoids a KeyError when a document lacks a "link" field
                # (the original indexed d.metadata["link"] directly).
                link = d.metadata.get("link", "")
                self.sources.append(d.page_content + "\nSource: " + link)

        def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
            if self.sources:
                # Display the reference docs with a Text widget
                sources_element = [
                    cl.Text(name=f"source_{idx+1}", content=content)
                    for idx, content in enumerate(self.sources)
                ]
                source_names = [el.name for el in sources_element]
                self.msg.elements += sources_element
                self.msg.content += f"\nSources: {', '.join(source_names)}"

    # Stream the answer token by token inside a Chainlit step, then finalize
    # the message (which now carries any source elements added by the handler).
    async with cl.Step(type="run", name="QA Assistant"):
        async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(
                callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
            ),
        ):
            await msg.stream_token(chunk)
        await msg.send()