# Julia Ostheimer
# Potentially fix deployment issue: Remove server specs in Gradio launch method
# c744766
from typing import Any
import gradio as gr
from langchain.chat_models import init_chat_model
from langchain_core.tools import tool
from langchain_aws import BedrockEmbeddings
from langchain_qdrant import QdrantVectorStore
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import MessagesState, StateGraph, END
from langgraph.prebuilt import ToolNode, tools_condition
import structlog
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, SparseVectorParams
import logging_config as _
# from conversation.main import graph
from conversation.generate import generate, trigger_ai_message_with_tool_call
from conversation.chat_source_history import prettify_source_history, build_source_history_object, clear_chat_and_source_history, create_chat_hash
from ingestion.main import ingest_document
from tools.langfuse_client import get_langfuse_handler
from config import app_settings
# Get Langfuse Callback handler (passed into graph.invoke() for LLM tracing)
langfuse_handler = get_langfuse_handler()
# Create a logger instance
logger = structlog.get_logger(__name__)
# Bedrock embedding model; used by the Qdrant vector store below for
# both document ingestion and similarity search.
embeddings = BedrockEmbeddings(
    model_id=app_settings.embedding_model,
    region_name=app_settings.llm_region,
    aws_access_key_id=app_settings.aws_access_key_id,
    aws_secret_access_key=app_settings.aws_secret_access_key,
)
# Chat model for the conversation flow.
# NOTE(review): `llm` is not referenced anywhere in this module — confirm it
# is consumed via an import elsewhere (e.g. conversation.generate) or remove.
llm = init_chat_model(
    app_settings.llm_model,
    model_provider=app_settings.model_provider,
    region_name=app_settings.llm_region,
    aws_access_key_id=app_settings.aws_access_key_id,
    aws_secret_access_key=app_settings.aws_secret_access_key,
)
# Connect to Qdrant and create the collection on first run: dense vectors
# with cosine distance, plus a sparse-vector config under the key
# 'langchain-sparse' (presumably for hybrid search — confirm).
client = QdrantClient(app_settings.vector_db_url)
if not client.collection_exists(app_settings.vector_db_collection_name):
    client.create_collection(
        collection_name=app_settings.vector_db_collection_name,
        vectors_config=VectorParams(size=app_settings.embedding_size, distance=Distance.COSINE),
        sparse_vectors_config={'langchain-sparse': SparseVectorParams(index=None, modifier=None)}
    )
# TODO: move to LLM files later
vector_store = QdrantVectorStore(
    client=client,
    collection_name=app_settings.vector_db_collection_name,
    embedding=embeddings,
)
# ------
# Move to `conversation/main` later
@tool(response_format="content_and_artifact")
def retrieve(query: str):
    """Retrieve information related to a query.

    Returns the top-10 matching documents from the vector store, both as a
    serialized text blob (content) and as the raw document list (artifact).
    """
    docs = vector_store.similarity_search(query, k=10)
    formatted = [
        f"Source: {doc.metadata}\nContent: {doc.page_content}" for doc in docs
    ]
    return "\n\n".join(formatted), docs
# Assemble the linear conversation graph:
#   trigger_ai_message_with_tool_call -> tool_node (retrieve) -> generate -> END
graph_builder = StateGraph(MessagesState)
# In-memory checkpointer: conversation state is keyed by the thread_id
# (chat hash) supplied in the invoke config inside bot().
memory = MemorySaver()
tool_node = ToolNode([retrieve])
graph_builder.add_node(trigger_ai_message_with_tool_call)
graph_builder.add_node("tool_node", tool_node)
graph_builder.add_node(generate)
graph_builder.set_entry_point("trigger_ai_message_with_tool_call")
graph_builder.add_edge("trigger_ai_message_with_tool_call", "tool_node")
graph_builder.add_edge("tool_node", "generate")
graph_builder.add_edge("generate", END)
graph = graph_builder.compile(checkpointer=memory)
# -----
def bot(message, history, source_history, chat_hash) -> tuple[list[Any], Any, Any]:
    """Generate bot response and history from message.

    With multi-modal inputs text and each file is treated as separate message.

    Args:
        message: Either a multimodal dict with "text" and "files" keys, or a
            plain string (e.g. when the user edits a previous message).
        history: Chat history supplied by gr.ChatInterface (logged only).
        source_history: Session-state list of source citations so far.
        chat_hash: Session-state thread id used for graph checkpointing.

    Returns:
        A 3-tuple of (list containing the assistant reply text, updated
        source history, unchanged chat hash).
    """
    logger.info("This is the history", history=history)
    # enable message edit: edited messages arrive as plain strings, so
    # normalize them into the multimodal dict shape.
    if isinstance(message, str):
        message = {"text": message}
    # process files; `or []` guards the edit path above, where no
    # "files" key exists and .get() would return None (iterating None raises).
    for file in message.get("files") or []:
        logger.info("Received file", file=file)
        ingest_document(file, vector_store)
    # create text response
    user_input_prompt = message.get("text")
    logger.info("This is the current chat_hash", chat_hash=chat_hash)
    # thread_id ties this invocation to the session's checkpointed state;
    # the Langfuse handler traces the LLM calls.
    config = {"configurable": {"thread_id": chat_hash}, "callbacks": [langfuse_handler], "stream": False}
    response = graph.invoke(
        {"messages": [{"role": "user", "content": user_input_prompt}]},
        config=config,
    )
    logger.info("Generated a response", response=response)
    logger.info("This is the source history before appending response", source_history=source_history)
    # Append source details of the response to source history
    source_history = build_source_history_object(response, source_history, user_input_prompt)
    return [response["messages"][-1].content], source_history, chat_hash
# --- Gradio UI: a chat tab wired to bot() and a sources tab that mirrors
# the accumulated citation history. ---
with gr.Blocks() as demo:
    # Initialize source history as session state variable
    source_history = gr.State([])
    # Per-session thread id; passing the callable lets Gradio create a
    # fresh hash for each new session.
    chat_hash = gr.State(create_chat_hash)
    gr.Markdown("### Dies ist der aktuelle Prototyp der DEval AI4Kontextanalysen, mit dem Fokus auf dem Testen der Zitationsfunktion und der grundlegenden Konfiguration.")
    with gr.Tab("Chat"):
        chatbot = gr.Chatbot(type="messages")
        # Clearing the chat also resets the source history and chat hash
        # so the next conversation starts on a fresh graph thread.
        chatbot.clear(
            clear_chat_and_source_history,
            inputs=[source_history, chat_hash],
            outputs=[source_history, chat_hash],
        )
        # additional_inputs/outputs thread the session state through bot()
        # on every turn.
        gr.ChatInterface(
            bot,
            chatbot=chatbot,
            type="messages",
            flagging_mode="manual",
            editable=True,
            flagging_options=["Like", "Dislike", "Spam", "Inappropriate", "Other"],
            multimodal=True,
            textbox=gr.MultimodalTextbox(file_count="multiple"),
            title="DEval Prototype 1",
            additional_inputs=[source_history, chat_hash],
            additional_outputs=[source_history, chat_hash],
        )
    with gr.Tab("Quellen"):
        prettified_source_history_md = gr.Markdown()
    # Update the source history UI immediately when the source_history value changes
    source_history.change(prettify_source_history, source_history, prettified_source_history_md)
if __name__ == "__main__":
    # Launch the app behind basic auth using the test credentials from
    # settings; SSR is disabled.
    credentials = (app_settings.test_user_name, app_settings.test_user_password)
    demo.launch(auth=credentials, ssr_mode=False)