File size: 5,254 Bytes
88a1870
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec7e160
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import uuid

import gradio as gr
from agent import construct_agent, initialize_agent_llm, initialize_instrumentor
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from openinference.instrumentation import using_session
from opentelemetry.trace import Status, StatusCode
from rag import initialize_vector_store

# Pull variables from a .env file into the process environment (python-dotenv)
# up front, before any later code reads os.environ (e.g. PORT at launch).
load_dotenv()

# System prompt placed at the head of the message list when a conversation
# starts with no prior history. The leading newline and four-space indents are
# part of the prompt text and are preserved exactly.
SYSTEM_MESSAGE_FOR_AGENT_WORKFLOW = (
    "\n"
    "    You are a Retrieval-Augmented Generation (RAG) assistant designed to provide responses by leveraging provided tools\n"
    "    Your goal is to ensure the user's query is addressed with quality. If further clarification is required,\n"
    "    you can request additional input from the user.\n"
    "    "
)


def initialize_agent(phoenix_key, project_name, openai_key, user_session_id, vector_source_web_url):
    """Configure credentials, tracing, LLMs and the vector store, then build the agent.

    The five returned values line up with the Gradio outputs of the
    "Set API Keys & Initialize" button: agent, tracer, tool model, session id
    and status message.

    Args:
        phoenix_key: Phoenix API key typed by the user. When blank, an already
            set ``PHOENIX_API_KEY`` environment variable is used instead.
        project_name: Project name passed to ``initialize_instrumentor``.
        openai_key: OpenAI API key typed by the user. When blank, an already
            set ``OPENAI_API_KEY`` environment variable is used instead.
        user_session_id: Session identifier, returned unchanged.
        vector_source_web_url: URL passed to ``initialize_vector_store``.

    Returns:
        ``(agent, tracer, tool_model, user_session_id, status_message)``.
        When a required key is available neither from the arguments nor from
        the environment, the first three items are ``None`` and the status
        message starts with ``"Error:"``; nothing is initialised in that case.
    """
    # Normalise the typed keys: None and whitespace-only both count as blank.
    supplied_keys = {
        "PHOENIX_API_KEY": (phoenix_key or "").strip(),
        "OPENAI_API_KEY": (openai_key or "").strip(),
    }

    # Validate before touching os.environ so a failed attempt has no side effects.
    missing = [
        env_name
        for env_name, key in supplied_keys.items()
        if not key and not os.environ.get(env_name)
    ]
    if missing:
        return (
            None,
            None,
            None,
            user_session_id,
            f"Error: missing required value(s): {', '.join(missing)}",
        )

    # Kept function-local as in the original code.
    # NOTE(review): presumably deferred to avoid import-time work/cycles - confirm.
    from tools import initialize_tool_llm

    # Only overwrite the environment with non-blank input, so keys loaded from
    # .env are not clobbered by an empty textbox.
    for env_name, key in supplied_keys.items():
        if key:
            os.environ[env_name] = key
    os.environ["PHOENIX_COLLECTOR_ENDPOINT"] = "https://app.phoenix.arize.com/v1/traces"

    agent_tracer = initialize_instrumentor(project_name)
    initialize_agent_llm("gpt-4o-mini")
    tool_model = initialize_tool_llm("gpt-4o-mini")
    initialize_vector_store(vector_source_web_url)
    copilot_agent = construct_agent()

    return (
        copilot_agent,
        agent_tracer,
        tool_model,
        user_session_id,
        f"Configuration Set: Project '{project_name}' is Ready!",
    )


def chat_with_agent(
    copilot_agent,
    agent_tracer,
    tool_model,
    user_input_message,
    user_session_id,
    user_chat_history,
    conversation_history,
):
    """Send one user message through the agent and extend the chat transcript.

    The six returned values line up with the Gradio outputs of the "Send"
    button: agent, input textbox, chatbot display, session id, chat-history
    state and conversation-history state.

    Args:
        copilot_agent: Agent object exposing ``invoke``; ``None`` until the
            configuration step has run.
        agent_tracer: Tracer exposing ``start_as_current_span``; ``None`` until
            the configuration step has run.
        tool_model: Tool LLM forwarded to the agent via the ``configurable``
            section of the invoke config.
        user_input_message: Text typed by the user.
        user_session_id: Session identifier, used as the tracing session id,
            the span-name suffix and the agent ``thread_id``.
        user_chat_history: Current transcript as a list of
            ``(user_message, reply)`` pairs; ``None`` is treated as empty.
        conversation_history: Result of the previous ``invoke`` call (a mapping
            with a ``"messages"`` list), or an empty/``None`` value on the
            first turn.

    Returns:
        ``(copilot_agent, textbox_value, transcript, user_session_id,
        transcript, conversation_history)``. On success ``textbox_value`` is
        ``""``. When the agent is not initialised, the transcript gains an
        error entry, the typed message is left in the textbox and
        ``conversation_history`` is returned unchanged.
    """
    # Work on a copy and tolerate a missing transcript.
    user_chat_history = list(user_chat_history or [])

    # Check the values passed in for this session, not the module-level
    # gr.State component, and return one value per declared output.
    if copilot_agent is None or agent_tracer is None:
        user_chat_history.append(
            (
                user_input_message,
                "Error: RAG Agent is not initialized. Please set API keys first.",
            )
        )
        return (
            copilot_agent,
            user_input_message,
            user_chat_history,
            user_session_id,
            user_chat_history,
            conversation_history,
        )

    # First turn: seed the conversation with the system prompt; later turns
    # continue from the messages returned by the previous invoke call.
    if not conversation_history:
        messages = [SystemMessage(content=SYSTEM_MESSAGE_FOR_AGENT_WORKFLOW)]
    else:
        messages = conversation_history["messages"]
    messages.append(HumanMessage(content=user_input_message))

    with using_session(session_id=user_session_id):
        with agent_tracer.start_as_current_span(
            f"agent-{user_session_id}",
            openinference_span_kind="chain",
        ) as span:
            span.set_input(user_input_message)
            conversation_history = copilot_agent.invoke(
                {"messages": messages},
                config={
                    "configurable": {
                        "thread_id": user_session_id,
                        "user_session_id": user_session_id,
                        "tool_model": tool_model,
                    }
                },
            )
            # The last message of the returned conversation is shown as the reply.
            reply = conversation_history["messages"][-1].content
            span.set_output(reply)
            span.set_status(Status(StatusCode.OK))

    user_chat_history.append((user_input_message, reply))
    return (
        copilot_agent,
        "",
        user_chat_history,
        user_session_id,
        user_chat_history,
        conversation_history,
    )


# Gradio UI: a configuration panel that builds the agent and a chat panel that
# talks to it. Per-session objects are carried between callbacks in gr.State.
with gr.Blocks() as demo:
    agent = gr.State(None)
    tracer = gr.State(None)
    openai_tool_model = gr.State(None)
    # Result of the agent's previous invoke call (a dict holding "messages");
    # empty until the first message has been sent.
    history = gr.State({})
    # Pass a callable so a fresh id is generated per page load; a plain
    # str(uuid.uuid4()) is evaluated once at import and shared by every visitor.
    session_id = gr.State(lambda: str(uuid.uuid4()))
    # Receives a copy of the transcript from chat_with_agent; it is not used as
    # an input to any callback in this file.
    chat_history = gr.State([])
    gr.Markdown("## Chat with RAG Agent 🔥")

    with gr.Row():
        with gr.Column(scale=1, min_width=250):
            gr.Markdown("### Configuration Panel ⚙️")

            phoenix_input = gr.Textbox(label="Phoenix API Key", type="password")
            project_input = gr.Textbox(label="Project Name", value="Agentic Rag")
            openai_input = gr.Textbox(label="OpenAI API Key", type="password")
            web_url = gr.Textbox(
                label="Vector Source Web URL",
                value="https://lilianweng.github.io/posts/2023-06-23-agent/",
            )
            set_button = gr.Button("Set API Keys & Initialize")
            output_message = gr.Textbox(label="Status", interactive=False)

            # Output order must match the tuple returned by initialize_agent.
            set_button.click(
                fn=initialize_agent,
                inputs=[phoenix_input, project_input, openai_input, session_id, web_url],
                outputs=[agent, tracer, openai_tool_model, session_id, output_message],
            )

        with gr.Column(scale=4):
            gr.Markdown("### Chat with RAG Agent 💬")

            chat_display = gr.Chatbot(label="Chat History", height=400)

            user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
            submit_button = gr.Button("Send")

            # Input/output order must match chat_with_agent's parameters and
            # its returned 6-tuple respectively.
            submit_button.click(
                fn=chat_with_agent,
                inputs=[
                    agent,
                    tracer,
                    openai_tool_model,
                    user_input,
                    session_id,
                    chat_display,
                    history,
                ],
                outputs=[agent, user_input, chat_display, session_id, chat_history, history],
            )

if __name__ == "__main__":
    # Listen on all interfaces; the PORT environment variable overrides the
    # default of 10000. share=True also requests a public Gradio share link.
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 10000)),
    )