from typing import Any
import gradio as gr
from langchain.chat_models import init_chat_model
from langchain_core.tools import tool
from langchain_aws import BedrockEmbeddings
from langchain_qdrant import QdrantVectorStore
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import MessagesState, StateGraph, END
from langgraph.prebuilt import ToolNode, tools_condition
import structlog
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, SparseVectorParams
import logging_config as _
# from conversation.main import graph
from conversation.generate import generate, trigger_ai_message_with_tool_call
from conversation.chat_source_history import prettify_source_history, build_source_history_object, clear_chat_and_source_history, create_chat_hash
from ingestion.main import ingest_document
from tools.langfuse_client import get_langfuse_handler
from config import app_settings
# Get Langfuse Callback handler (observability/tracing for LLM calls)
langfuse_handler = get_langfuse_handler()
# Create a logger instance
logger = structlog.get_logger(__name__)
# Embedding model served via AWS Bedrock; credentials come from app settings.
embeddings = BedrockEmbeddings(
    model_id=app_settings.embedding_model,
    region_name=app_settings.llm_region,
    aws_access_key_id=app_settings.aws_access_key_id,
    aws_secret_access_key=app_settings.aws_secret_access_key,
)
# Chat LLM, provider-agnostic init (provider/model configured in app settings).
llm = init_chat_model(
    app_settings.llm_model,
    model_provider=app_settings.model_provider,
    region_name=app_settings.llm_region,
    aws_access_key_id=app_settings.aws_access_key_id,
    aws_secret_access_key=app_settings.aws_secret_access_key,
)
# Qdrant vector database client; create the collection on first run only.
client = QdrantClient(app_settings.vector_db_url)
if not client.collection_exists(app_settings.vector_db_collection_name):
    # Dense vectors use cosine distance; a sparse-vector slot is also
    # configured under the 'langchain-sparse' name for hybrid search.
    client.create_collection(
        collection_name=app_settings.vector_db_collection_name,
        vectors_config=VectorParams(size=app_settings.embedding_size, distance=Distance.COSINE),
        sparse_vectors_config={'langchain-sparse': SparseVectorParams(index=None, modifier=None)}
    )
# TODO: move to LLM files later
vector_store = QdrantVectorStore(
    client=client,
    collection_name=app_settings.vector_db_collection_name,
    embedding=embeddings,
)
# ------
# Move to `conversation/main`` later
@tool(response_format="content_and_artifact")
def retrieve(query: str):
    """Retrieve information related to a query."""
    # Top-10 nearest chunks from the vector store for the given query.
    docs = vector_store.similarity_search(query, k=10)
    rendered = [
        f"Source: {doc.metadata}\nContent: {doc.page_content}"
        for doc in docs
    ]
    # Content (string for the LLM) plus the raw documents as the artifact.
    return "\n\n".join(rendered), docs
# Build the linear RAG conversation graph:
# trigger a tool-calling AI message -> run the retrieval tool -> generate the answer.
graph_builder = StateGraph(MessagesState)
memory = MemorySaver()
tool_node = ToolNode([retrieve])
graph_builder.add_node(
    "trigger_ai_message_with_tool_call", trigger_ai_message_with_tool_call
)
graph_builder.add_node("tool_node", tool_node)
graph_builder.add_node("generate", generate)
graph_builder.set_entry_point("trigger_ai_message_with_tool_call")
for source_node, target_node in (
    ("trigger_ai_message_with_tool_call", "tool_node"),
    ("tool_node", "generate"),
    ("generate", END),
):
    graph_builder.add_edge(source_node, target_node)
# Checkpointer keeps per-thread conversation state in memory.
graph = graph_builder.compile(checkpointer=memory)
# -----
def bot(message, history, source_history, chat_hash) -> tuple[list[Any], Any, Any]:
    """Generate bot response and history from message.

    With multi-modal inputs text and each file is treated as separate message.

    Args:
        message: Either a plain text string (e.g. from a message edit) or a
            multimodal dict with "text" and "files" keys, as produced by
            ``gr.MultimodalTextbox``.
        history: Chat history supplied by ``gr.ChatInterface`` (only logged
            here; conversational state lives in the graph checkpointer).
        source_history: Accumulated source/citation entries (session state).
        chat_hash: Stable session identifier used as the LangGraph thread id.

    Returns:
        A 3-tuple of ``([assistant reply text], updated source_history,
        chat_hash)`` matching ``additional_outputs`` of the ChatInterface.
    """
    logger.info("This is the history", history=history)
    # enable message edit: normalize plain strings to the multimodal dict shape
    if isinstance(message, str):
        message = {"text": message}
    # Ingest attached files into the vector store before answering.
    # Default to an empty list: a normalized string message carries no
    # "files" key, and iterating over message.get("files") == None would
    # raise a TypeError.
    for file in message.get("files", []):
        logger.info("Received file", file=file)
        ingest_document(file, vector_store)
    # create text response
    user_input_prompt = message.get("text")
    logger.info("This is the current chat_hash", chat_hash=chat_hash)
    # thread_id scopes the checkpointer memory to this chat session.
    config = {"configurable": {"thread_id": chat_hash}, "callbacks": [langfuse_handler], "stream": False}
    response = graph.invoke(
        {"messages": [{"role": "user", "content": user_input_prompt}]},
        config=config,
    )
    logger.info("Generated a response", response=response)
    logger.info("This is the source history before appending response", source_history=source_history)
    # Append source details of the response to source history
    source_history = build_source_history_object(response, source_history, user_input_prompt)
    return [response["messages"][-1].content], source_history, chat_hash
# Gradio UI: a chat tab (multimodal chat with flagging) and a sources tab
# that renders the accumulated citation history.
with gr.Blocks() as demo:
    # Initialize source history as session state variable
    source_history = gr.State([])
    # Per-session thread id for the LangGraph checkpointer; created lazily
    # by passing the factory callable to gr.State.
    chat_hash = gr.State(create_chat_hash)
    gr.Markdown("### Dies ist der aktuelle Prototyp der DEval AI4Kontextanalysen, mit dem Fokus auf dem Testen der Zitationsfunktion und der grundlegenden Konfiguration.")
    with gr.Tab("Chat"):
        chatbot = gr.Chatbot(type="messages")
        # Clearing the chat also resets source history and issues a new
        # chat hash so a fresh graph thread is used.
        chatbot.clear(
            clear_chat_and_source_history,
            inputs=[source_history, chat_hash],
            outputs=[source_history, chat_hash],
        )
        # additional_inputs/outputs thread the session state through bot()
        gr.ChatInterface(
            bot,
            chatbot=chatbot,
            type="messages",
            flagging_mode="manual",
            editable=True,
            flagging_options=["Like", "Dislike", "Spam", "Inappropriate", "Other"],
            multimodal=True,
            textbox=gr.MultimodalTextbox(file_count="multiple"),
            title="DEval Prototype 1",
            additional_inputs=[source_history, chat_hash],
            additional_outputs=[source_history, chat_hash],
        )
    with gr.Tab("Quellen"):
        prettified_source_history_md = gr.Markdown()
    # Update the source history UI immediately when the source_history value changes
    source_history.change(prettify_source_history, source_history, prettified_source_history_md)
if __name__ == "__main__":
    # Launch the app behind simple basic-auth test credentials from settings.
    credentials = (app_settings.test_user_name, app_settings.test_user_password)
    demo.launch(auth=credentials, ssr_mode=False)