import chainlit as cl
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable, RunnableConfig, RunnablePassthrough
from langchain.vectorstores.chroma import Chroma
from langchain_google_genai import (
    GoogleGenerativeAI,
    GoogleGenerativeAIEmbeddings,
    HarmBlockThreshold,
    HarmCategory,
)

import config
from prompts import prompt

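# Metadata fields the self-query retriever is allowed to filter on.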
metadata_field_info = [
    AttributeInfo(
        name="title",
        description="Le titre de l'article",
        type="string",
    ),
    AttributeInfo(
        name="date",
        description="Date de publication",
        type="string",
    ),
    AttributeInfo(name="link", description="Source de l'article", type="string"),
]
document_content_description = "Articles sur l'actualité."

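# Gemini chat model, also used by the self-query retriever to build its queries.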
model = GoogleGenerativeAI(
    model=config.GOOGLE_CHAT_MODEL,
    google_api_key=config.GOOGLE_API_KEY,
    safety_settings={
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    },
)

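# Google embedding model used to embed queries against the Chroma index.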
embedding = GoogleGenerativeAIEmbeddings(
    model="models/embedding-001", google_api_key=config.GOOGLE_API_KEY
)

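# Persisted Chroma vector store, reopened from config.STORAGE_PATH.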
vectordb = Chroma(persist_directory=config.STORAGE_PATH, embedding_function=embedding)

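# Self-query retriever: the LLM turns a user question into a vector query plus
# optional metadata filters (title, date, link).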
retriever = SelfQueryRetriever.from_llm(
    model,
    vectordb,
    document_content_description,
    metadata_field_info,
)

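# Runs once per chat session: build the RAG chain and send a welcome message.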
@cl.on_chat_start
async def on_chat_start():
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Retrieve with the self-query retriever defined above, format the documents
    # into the prompt context, then parse the model output to a plain string.
    rag_chain = (
        {
            "context": retriever | format_docs,
            "question": RunnablePassthrough(),
        }
        | prompt
        | model
        | StrOutputParser()
    )

    # Store the chain in the user session so on_message can reuse it.
    cl.user_session.set("rag_chain", rag_chain)

    msg = cl.Message(
        content="Vous pouvez poser vos questions sur les articles de SIKAFINANCE",
    )
    await msg.send()


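# Runs on every user message: stream the answer and attach the retrieved sources.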
@cl.on_message
async def on_message(message: cl.Message):
    runnable: Runnable = cl.user_session.get("rag_chain")
    msg = cl.Message(content="")

    class PostMessageHandler(BaseCallbackHandler):
        """
        Callback handler for the retriever and LLM steps.
        Posts the sources of the retrieved documents as Chainlit text elements.
        """

        def __init__(self, msg: cl.Message):
            BaseCallbackHandler.__init__(self)
            self.msg = msg
            self.sources = []

        def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
            # Collect the content and "link" metadata of every retrieved document.
            for d in documents:
                source_doc = d.page_content + "\nSource: " + d.metadata["link"]
                self.sources.append(source_doc)

        def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
            # Once the answer is complete, attach the sources to the message.
            if self.sources:
                sources_elements = [
                    cl.Text(name=f"source_{idx+1}", content=content)
                    for idx, content in enumerate(self.sources)
                ]
                source_names = [el.name for el in sources_elements]
                self.msg.elements += sources_elements
                self.msg.content += f"\nSources: {', '.join(source_names)}"

    # Stream the answer token by token inside a Chainlit step, wiring both the
    # Chainlit LangChain tracer and the source-posting handler as callbacks.
    async with cl.Step(type="run", name="QA Assistant"):
        async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(
                callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
            ),
        ):
            await msg.stream_token(chunk)

    await msg.send()