Julia Ostheimer
Potentially fix deployment issue: Remove server specs in Gradio launch method
c744766
| from typing import Any | |
| import gradio as gr | |
| from langchain.chat_models import init_chat_model | |
| from langchain_core.tools import tool | |
| from langchain_aws import BedrockEmbeddings | |
| from langchain_qdrant import QdrantVectorStore | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langgraph.graph import MessagesState, StateGraph, END | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| import structlog | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.http.models import Distance, VectorParams, SparseVectorParams | |
| import logging_config as _ | |
| # from conversation.main import graph | |
| from conversation.generate import generate, trigger_ai_message_with_tool_call | |
| from conversation.chat_source_history import prettify_source_history, build_source_history_object, clear_chat_and_source_history, create_chat_hash | |
| from ingestion.main import ingest_document | |
| from tools.langfuse_client import get_langfuse_handler | |
| from config import app_settings | |
# Get Langfuse Callback handler (tracing/observability for LLM calls)
langfuse_handler = get_langfuse_handler()
# Create a logger instance
logger = structlog.get_logger(__name__)

# Embedding model served by AWS Bedrock; model id, region and credentials
# all come from application settings.
embeddings = BedrockEmbeddings(
    model_id=app_settings.embedding_model,
    region_name=app_settings.llm_region,
    aws_access_key_id=app_settings.aws_access_key_id,
    aws_secret_access_key=app_settings.aws_secret_access_key,
)

# Chat model used for response generation; provider/region/credentials are
# resolved from application settings via LangChain's init_chat_model.
llm = init_chat_model(
    app_settings.llm_model,
    model_provider=app_settings.model_provider,
    region_name=app_settings.llm_region,
    aws_access_key_id=app_settings.aws_access_key_id,
    aws_secret_access_key=app_settings.aws_secret_access_key,
)

# Qdrant vector-database client; bootstrap the collection on first run.
client = QdrantClient(app_settings.vector_db_url)
if not client.collection_exists(app_settings.vector_db_collection_name):
    client.create_collection(
        collection_name=app_settings.vector_db_collection_name,
        # Dense vectors: size must match the embedding model's output size;
        # cosine distance is used for similarity search.
        vectors_config=VectorParams(size=app_settings.embedding_size, distance=Distance.COSINE),
        # Sparse-vector config declared under the name LangChain expects;
        # presumably reserved for hybrid search — TODO confirm usage.
        sparse_vectors_config={'langchain-sparse': SparseVectorParams(index=None, modifier=None)}
    )

# TODO: move to LLM files later
# LangChain wrapper that pairs the Qdrant collection with the embedding model.
vector_store = QdrantVectorStore(
    client=client,
    collection_name=app_settings.vector_db_collection_name,
    embedding=embeddings,
)
| # ------ | |
| # Move to `conversation/main`` later | |
# NOTE: the docstring below doubles as the tool description exposed to the
# LLM via ToolNode, so it is kept verbatim.
def retrieve(query: str):
    """Retrieve information related to a query."""
    matches = vector_store.similarity_search(query, k=10)
    rendered = [
        f"Source: {doc.metadata}\nContent: {doc.page_content}"
        for doc in matches
    ]
    return "\n\n".join(rendered), matches
# Build the conversation graph: trigger a tool call -> run retrieval tool
# -> generate the final answer. State is the standard MessagesState.
graph_builder = StateGraph(MessagesState)
# In-memory checkpointer: conversation state is keyed by thread_id and lost
# on process restart.
memory = MemorySaver()
tool_node = ToolNode([retrieve])
# Node name defaults to the function name ("trigger_ai_message_with_tool_call").
graph_builder.add_node(trigger_ai_message_with_tool_call)
graph_builder.add_node("tool_node", tool_node)
# Node name defaults to the function name ("generate").
graph_builder.add_node(generate)
graph_builder.set_entry_point("trigger_ai_message_with_tool_call")
# Linear flow — retrieval always runs before generation.
graph_builder.add_edge("trigger_ai_message_with_tool_call", "tool_node")
graph_builder.add_edge("tool_node", "generate")
graph_builder.add_edge("generate", END)
graph = graph_builder.compile(checkpointer=memory)
| # ----- | |
def bot(message, history, source_history, chat_hash) -> tuple[list[Any], Any, Any]:
    """Generate bot response and history from message.

    With multi-modal inputs text and each file is treated as separate message.

    Args:
        message: Either a plain string (e.g. an edited message) or a
            multimodal dict with a "text" key and an optional "files" key,
            as produced by gr.MultimodalTextbox.
        history: Chat history supplied by gr.ChatInterface (only logged here).
        source_history: Accumulated per-answer source metadata (session state).
        chat_hash: Session identifier used as the LangGraph thread_id.

    Returns:
        Tuple of (list with the assistant's reply text, updated source
        history, unchanged chat hash) — matching ChatInterface's
        additional_outputs wiring.
    """
    logger.info("This is the history", history=history)
    # enable message edit: plain strings are normalized to the multimodal
    # dict shape (note: such dicts have no "files" key).
    if isinstance(message, str):
        message = {"text": message}
    # process files: ingest each attachment into the vector store.
    # `or []` guards against a missing "files" key (normalized strings above)
    # and an explicit None, either of which would raise TypeError if iterated.
    for file in message.get("files") or []:
        logger.info("Received file", file=file)
        ingest_document(file, vector_store)
    # create text response
    user_input_prompt = message.get("text")
    logger.info("This is the current chat_hash", chat_hash=chat_hash)
    # thread_id scopes the MemorySaver checkpoint; Langfuse handler traces the run.
    config = {"configurable": {"thread_id": chat_hash}, "callbacks": [langfuse_handler], "stream": False}
    response = graph.invoke(
        {"messages": [{"role": "user", "content": user_input_prompt}]},
        config=config,
    )
    logger.info("Generated a response", response=response)
    logger.info("This is the source history before appending response", source_history=source_history)
    # Append source details of the response to source history
    source_history = build_source_history_object(response, source_history, user_input_prompt)
    # Last graph message is the generated answer.
    return [response["messages"][-1].content], source_history, chat_hash
# Gradio UI: a chat tab backed by `bot` plus a sources tab that renders the
# accumulated citation history.
with gr.Blocks() as demo:
    # Initialize source history as session state variable
    source_history = gr.State([])
    # Per-session thread id; callable default so each session gets a fresh hash.
    chat_hash = gr.State(create_chat_hash)
    gr.Markdown("### Dies ist der aktuelle Prototyp der DEval AI4Kontextanalysen, mit dem Fokus auf dem Testen der Zitationsfunktion und der grundlegenden Konfiguration.")
    with gr.Tab("Chat"):
        chatbot = gr.Chatbot(type="messages")
        # Clearing the chat also resets the source history and chat hash.
        chatbot.clear(
            clear_chat_and_source_history,
            inputs=[source_history, chat_hash],
            outputs=[source_history, chat_hash],
        )
        # Multimodal chat interface; `bot` receives and returns the extra
        # session-state values via additional_inputs/additional_outputs.
        gr.ChatInterface(
            bot,
            chatbot=chatbot,
            type="messages",
            flagging_mode="manual",
            editable=True,
            flagging_options=["Like", "Dislike", "Spam", "Inappropriate", "Other"],
            multimodal=True,
            textbox=gr.MultimodalTextbox(file_count="multiple"),
            title="DEval Prototype 1",
            additional_inputs=[source_history, chat_hash],
            additional_outputs=[source_history, chat_hash],
        )
    with gr.Tab("Quellen"):
        prettified_source_history_md = gr.Markdown()
        # Update the source history UI immediately when the source_history value changes
        source_history.change(prettify_source_history, source_history, prettified_source_history_md)
if __name__ == "__main__":
    # Basic-auth credentials for the test deployment, taken from settings.
    credentials = (app_settings.test_user_name, app_settings.test_user_password)
    # ssr_mode=False disables server-side rendering (deployment workaround);
    # no host/port specified so Gradio uses its defaults.
    demo.launch(auth=credentials, ssr_mode=False)