from typing import Any

import gradio as gr
import structlog
from langchain.chat_models import init_chat_model
from langchain_core.tools import tool
from langchain_aws import BedrockEmbeddings
from langchain_qdrant import QdrantVectorStore
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import MessagesState, StateGraph, END
from langgraph.prebuilt import ToolNode, tools_condition
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, SparseVectorParams

import logging_config as _  # imported only for its side effect: configures logging
# from conversation.main import graph
from conversation.generate import generate, trigger_ai_message_with_tool_call
from conversation.chat_source_history import (
    prettify_source_history,
    build_source_history_object,
    clear_chat_and_source_history,
    create_chat_hash,
)
from ingestion.main import ingest_document
from tools.langfuse_client import get_langfuse_handler
from config import app_settings

# Get Langfuse Callback handler (traces every graph invocation)
langfuse_handler = get_langfuse_handler()

# Create a logger instance
logger = structlog.get_logger(__name__)

# Embedding model served via AWS Bedrock; used both for ingestion and retrieval.
embeddings = BedrockEmbeddings(
    model_id=app_settings.embedding_model,
    region_name=app_settings.llm_region,
    aws_access_key_id=app_settings.aws_access_key_id,
    aws_secret_access_key=app_settings.aws_secret_access_key,
)

# Chat model used by the conversation graph nodes.
llm = init_chat_model(
    app_settings.llm_model,
    model_provider=app_settings.model_provider,
    region_name=app_settings.llm_region,
    aws_access_key_id=app_settings.aws_access_key_id,
    aws_secret_access_key=app_settings.aws_secret_access_key,
)

# Ensure the Qdrant collection exists before the vector store wraps it.
client = QdrantClient(app_settings.vector_db_url)
if not client.collection_exists(app_settings.vector_db_collection_name):
    client.create_collection(
        collection_name=app_settings.vector_db_collection_name,
        vectors_config=VectorParams(
            size=app_settings.embedding_size, distance=Distance.COSINE
        ),
        sparse_vectors_config={
            'langchain-sparse': SparseVectorParams(index=None, modifier=None)
        },
    )

# TODO: move to LLM files later
vector_store = QdrantVectorStore(
    client=client,
    collection_name=app_settings.vector_db_collection_name,
    embedding=embeddings,
)


# ------
# Move to `conversation/main`` later
@tool(response_format="content_and_artifact")
def retrieve(query: str):
    """Retrieve information related to a query."""
    retrieved_docs = vector_store.similarity_search(query, k=10)
    # `content_and_artifact` tools return (string for the LLM, raw artifact).
    serialized = "\n\n".join(
        (f"Source: {doc.metadata}\n" f"Content: {doc.page_content}")
        for doc in retrieved_docs
    )
    return serialized, retrieved_docs


# Conversation graph: trigger tool call -> run retrieval tool -> generate answer.
graph_builder = StateGraph(MessagesState)
memory = MemorySaver()
tool_node = ToolNode([retrieve])
graph_builder.add_node(trigger_ai_message_with_tool_call)
graph_builder.add_node("tool_node", tool_node)
graph_builder.add_node(generate)
graph_builder.set_entry_point("trigger_ai_message_with_tool_call")
graph_builder.add_edge("trigger_ai_message_with_tool_call", "tool_node")
graph_builder.add_edge("tool_node", "generate")
graph_builder.add_edge("generate", END)
# MemorySaver checkpointer keeps per-thread (chat_hash) conversation state.
graph = graph_builder.compile(checkpointer=memory)
# -----


def bot(message, history, source_history, chat_hash) -> list[Any]:
    """Generate bot response and history from message.

    With multi-modal inputs text and each file is treated as separate message.

    Args:
        message: Either a plain string (message-edit path) or a multimodal
            dict with ``"text"`` and optionally ``"files"`` keys.
        history: Prior chat turns supplied by Gradio (logged only).
        source_history: Accumulated citation/source records (session state).
        chat_hash: Stable id used as the LangGraph thread id (session state).

    Returns:
        ``[answer_text], source_history, chat_hash`` matching the
        ChatInterface ``additional_outputs`` contract.
    """
    logger.info("This is the history", history=history)

    # enable message edit: edited messages arrive as a bare string, so
    # normalize them to the multimodal dict shape.
    if isinstance(message, str):
        message = {"text": message}

    # process files
    # BUGFIX: a normalized string message has no "files" key, so .get()
    # returns None and iterating it raised TypeError; default to [].
    for file in message.get("files") or []:
        logger.info("Received file", file=file)
        ingest_document(file, vector_store)

    # create text response
    user_input_prompt = message.get("text")
    logger.info("This is the current chat_hash", chat_hash=chat_hash)
    config = {
        "configurable": {"thread_id": chat_hash},
        "callbacks": [langfuse_handler],
        "stream": False,
    }
    response = graph.invoke(
        {"messages": [{"role": "user", "content": user_input_prompt}]},
        config=config,
    )
    logger.info("Generated a response", response=response)
    logger.info(
        "This is the source history before appending response",
        source_history=source_history,
    )

    # Append source details of the response to source history
    source_history = build_source_history_object(
        response, source_history, user_input_prompt
    )
    return [response["messages"][-1].content], source_history, chat_hash


with gr.Blocks() as demo:
    # Initialize source history as session state variable
    source_history = gr.State([])
    # Callable initializer -> each session gets its own fresh chat hash.
    chat_hash = gr.State(create_chat_hash)

    gr.Markdown("### Dies ist der aktuelle Prototyp der DEval AI4Kontextanalysen, mit dem Fokus auf dem Testen der Zitationsfunktion und der grundlegenden Konfiguration.")

    with gr.Tab("Chat"):
        chatbot = gr.Chatbot(type="messages")
        # Clearing the chat also resets source history and issues a new hash.
        chatbot.clear(
            clear_chat_and_source_history,
            inputs=[source_history, chat_hash],
            outputs=[source_history, chat_hash],
        )
        gr.ChatInterface(
            bot,
            chatbot=chatbot,
            type="messages",
            flagging_mode="manual",
            editable=True,
            flagging_options=["Like", "Dislike", "Spam", "Inappropriate", "Other"],
            multimodal=True,
            textbox=gr.MultimodalTextbox(file_count="multiple"),
            title="DEval Prototype 1",
            additional_inputs=[source_history, chat_hash],
            additional_outputs=[source_history, chat_hash],
        )

    with gr.Tab("Quellen"):
        prettified_source_history_md = gr.Markdown()
        # Update the source history UI immediately when the source_history value changes
        source_history.change(
            prettify_source_history, source_history, prettified_source_history_md
        )


if __name__ == "__main__":
    demo.launch(
        auth=(
            app_settings.test_user_name,
            app_settings.test_user_password,
        ),
        ssr_mode=False,
    )