File size: 5,946 Bytes
2eeb3a4
 
 
5cacb85
 
77a5434
627ec3c
5cacb85
 
 
3051bba
2eeb3a4
 
627ec3c
 
 
2eeb3a4
5cacb85
bf0eea7
546b27b
627ec3c
5fea2a2
627ec3c
 
2eeb3a4
5fea2a2
 
3051bba
2eeb3a4
 
 
77a5434
 
 
 
 
627ec3c
 
5cacb85
 
7e508e0
77a5434
 
 
5cacb85
 
627ec3c
 
 
 
 
 
 
5cacb85
627ec3c
 
 
 
 
 
5cacb85
 
 
 
 
4707e45
5cacb85
 
 
 
4759180
5cacb85
 
 
 
 
a959e84
5cacb85
a959e84
 
 
5cacb85
a959e84
 
 
5cacb85
627ec3c
5cacb85
 
627ec3c
7c247f1
2eeb3a4
 
 
 
627ec3c
2eeb3a4
 
 
 
 
627ec3c
 
 
 
 
2eeb3a4
 
67f186b
7c247f1
6e72fd0
7c247f1
2b0b7e6
627ec3c
67f186b
627ec3c
 
2eeb3a4
627ec3c
2eeb3a4
60d25c3
8366a22
ccf1e43
2eeb3a4
7c247f1
038c68a
2eeb3a4
4e96136
8366a22
15d7f94
7c247f1
768cf00
8398286
4e96136
038c68a
 
 
 
 
 
4e96136
cf824b9
 
 
 
 
 
 
 
 
 
 
4e96136
 
8366a22
 
 
768cf00
2eeb3a4
285d1cd
 
 
 
7de342c
 
285d1cd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
from typing import Any

import gradio as gr
from langchain.chat_models import init_chat_model
from langchain_core.tools import tool
from langchain_aws import BedrockEmbeddings
from langchain_qdrant import QdrantVectorStore
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import MessagesState, StateGraph, END
from langgraph.prebuilt import ToolNode, tools_condition

import structlog

from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, SparseVectorParams

import logging_config as _
# from conversation.main import graph
from conversation.generate import generate, trigger_ai_message_with_tool_call
from conversation.chat_source_history import prettify_source_history, build_source_history_object, clear_chat_and_source_history, create_chat_hash
from ingestion.main import ingest_document
from tools.langfuse_client import get_langfuse_handler

from config import app_settings

# Get Langfuse Callback handler (used to trace graph invocations in Langfuse)
langfuse_handler = get_langfuse_handler()

# Create a logger instance (structlog, configured via logging_config import above)
logger = structlog.get_logger(__name__)

# Embedding model served via AWS Bedrock; credentials come from app_settings.
embeddings = BedrockEmbeddings(
    model_id=app_settings.embedding_model,
    region_name=app_settings.llm_region,
    aws_access_key_id=app_settings.aws_access_key_id,
    aws_secret_access_key=app_settings.aws_secret_access_key,
)

# Chat LLM initialized through LangChain's provider-agnostic factory.
# NOTE(review): `llm` is not referenced elsewhere in this file — confirm it is
# used by imported modules or remove it.
llm = init_chat_model(
    app_settings.llm_model,
    model_provider=app_settings.model_provider,
    region_name=app_settings.llm_region,
    aws_access_key_id=app_settings.aws_access_key_id,
    aws_secret_access_key=app_settings.aws_secret_access_key,
)

# Connect to Qdrant and lazily create the collection on first run.
client = QdrantClient(app_settings.vector_db_url)
if not client.collection_exists(app_settings.vector_db_collection_name):
    client.create_collection(
        collection_name=app_settings.vector_db_collection_name,
        # Dense vectors sized to the embedding model, cosine distance.
        vectors_config=VectorParams(size=app_settings.embedding_size, distance=Distance.COSINE),
        # Sparse vector config with LangChain's expected key name.
        sparse_vectors_config={'langchain-sparse': SparseVectorParams(index=None, modifier=None)}
    )
# TODO: move to LLM files later
# Vector store wrapper used for both ingestion and retrieval.
vector_store = QdrantVectorStore(
    client=client,
    collection_name=app_settings.vector_db_collection_name,
    embedding=embeddings,
)

# ------
# Move to `conversation/main`` later
@tool(response_format="content_and_artifact")
def retrieve(query: str):
    """Retrieve information related to a query."""
    # Top-10 similarity hits from the shared Qdrant vector store.
    documents = vector_store.similarity_search(query, k=10)
    # Flatten each document into a "Source/Content" text chunk for the LLM;
    # the raw documents travel alongside as the tool artifact.
    chunks = [
        f"Source: {document.metadata}\nContent: {document.page_content}"
        for document in documents
    ]
    return "\n\n".join(chunks), documents


# Build the conversation graph: force a retrieval tool call, execute it,
# then generate the final answer.
graph_builder = StateGraph(MessagesState)
# In-memory checkpointer — per-thread state keyed by thread_id (chat_hash),
# lost on process restart.
memory = MemorySaver()

tool_node = ToolNode([retrieve])

# Function nodes registered without an explicit name take the function's name.
graph_builder.add_node(trigger_ai_message_with_tool_call)
graph_builder.add_node("tool_node", tool_node)
graph_builder.add_node(generate)

# Linear flow: trigger tool call -> run retrieval -> generate answer -> END.
graph_builder.set_entry_point("trigger_ai_message_with_tool_call")
graph_builder.add_edge("trigger_ai_message_with_tool_call", "tool_node")
graph_builder.add_edge("tool_node", "generate")
graph_builder.add_edge("generate", END)

graph = graph_builder.compile(checkpointer=memory)
# -----

def bot(message, history, source_history, chat_hash) -> list[Any]:
    """Generate bot response and history from message.

    With multi-modal inputs text and each file is treated as separate message.

    Args:
        message: Either a plain string (e.g. from message editing) or a
            multimodal dict with a "text" key and an optional "files" list.
        history: Chat history supplied by Gradio (only logged here).
        source_history: Accumulated source metadata from previous turns.
        chat_hash: Session identifier used as the LangGraph thread id.

    Returns:
        A 3-item list: [assistant reply text], updated source_history,
        and the unchanged chat_hash (Gradio additional_outputs).
    """
    logger.info("This is the history", history=history)

    # enable message edit: normalize a bare string into the multimodal shape
    if isinstance(message, str):
        message = {"text": message}

    # process files — default to an empty list: a normalized plain-text
    # message has no "files" key, and iterating None would raise TypeError.
    for file in message.get("files") or []:
        logger.info("Received file", file=file)
        ingest_document(file, vector_store)

    # create text response
    user_input_prompt = message.get("text")
    logger.info("This is the current chat_hash", chat_hash=chat_hash)

    # thread_id scopes the MemorySaver checkpoint; Langfuse handler traces the run.
    config = {"configurable": {"thread_id": chat_hash}, "callbacks": [langfuse_handler], "stream": False}

    response = graph.invoke(
        {"messages": [{"role": "user", "content": user_input_prompt}]},
        config=config,
    )

    logger.info("Generated a response", response=response)

    logger.info("This is the source history before appending response", source_history=source_history)
    # Append source details of the response to source history
    source_history = build_source_history_object(response, source_history, user_input_prompt)

    return [response["messages"][-1].content], source_history, chat_hash


# Build the Gradio UI: a Chat tab wired to `bot` and a sources ("Quellen") tab
# that renders the accumulated source history.
with gr.Blocks() as demo:
    # Initialize source history as session state variable
    source_history = gr.State([])
    # Fresh per-session chat hash; passing the callable defers creation per session.
    chat_hash = gr.State(create_chat_hash)

    gr.Markdown("### Dies ist der aktuelle Prototyp der DEval AI4Kontextanalysen, mit dem Fokus auf dem Testen der Zitationsfunktion und der grundlegenden Konfiguration.")
    with gr.Tab("Chat"):
        chatbot = gr.Chatbot(type="messages")
        # Clearing the chat also resets source history and issues a new chat hash.
        chatbot.clear(
            clear_chat_and_source_history,
            inputs=[source_history, chat_hash],
            outputs=[source_history, chat_hash],
        )
        gr.ChatInterface(
        bot,
        chatbot=chatbot,
        type="messages",
        flagging_mode="manual",
        editable=True, 
        flagging_options=["Like", "Dislike", "Spam", "Inappropriate", "Other"],
        multimodal=True,
        textbox=gr.MultimodalTextbox(file_count="multiple"),
        title="DEval Prototype 1",
        additional_inputs=[source_history, chat_hash],
        additional_outputs=[source_history, chat_hash],
        )
    with gr.Tab("Quellen"):
        prettified_source_history_md = gr.Markdown()
        # Update the source history UI immediately when the source_history value changes
        source_history.change(prettify_source_history, source_history, prettified_source_history_md)

if __name__ == "__main__":
    # Launch behind HTTP basic auth using the configured test account;
    # server-side rendering disabled.
    credentials = (app_settings.test_user_name, app_settings.test_user_password)
    demo.launch(auth=credentials, ssr_mode=False)