MusaR committed on
Commit
4c8af35
·
verified ·
1 Parent(s): c3d1dc9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -124
app.py CHANGED
@@ -1,142 +1,172 @@
1
- # app.py (The "Glass Box" UI)
2
-
3
- import streamlit as st
4
  import google.generativeai as genai
5
  from tavily import TavilyClient
6
- from rag_agent import ResearchAgent, AITools # We now import the agent class
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- # --- Page Configuration ---
9
- st.set_page_config(page_title="DeepSearch Agent", layout="wide", initial_sidebar_state="expanded")
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- # --- CSS for a professional look ---
12
- st.markdown("""
13
- <style>
14
- /* Main container and text styling */
15
- .stApp {
16
- background-color: #0F172A; /* Slate 900 */
17
- color: #E2E8F0; /* Slate 200 */
18
- }
19
- h1, h2, h3 {
20
- color: #F8FAFC; /* Slate 50 */
21
- }
22
- /* Expander for logs */
23
- .streamlit-expanderHeader {
24
- background-color: #1E293B; /* Slate 800 */
25
- color: #E2E8F0;
26
- }
27
- /* Chat input styling */
28
- .stTextInput > div > div > input {
29
- background-color: #1E293B;
30
- color: #E2E8F0;
31
- }
32
- /* Button styling */
33
- .stButton > button {
34
- background-color: #2563EB; /* Blue 600 */
35
- color: white;
36
- border: none;
37
- }
38
- .stButton > button:hover {
39
- background-color: #1D4ED8; /* Blue 700 */
40
- }
41
- </style>
42
- """, unsafe_allow_html=True)
43
 
44
- # --- Session State Management ---
45
- if "agent" not in st.session_state:
46
- st.session_state.agent = None
47
- if "messages" not in st.session_state:
48
- st.session_state.messages = []
49
- if "agent_state" not in st.session_state:
50
- st.session_state.agent_state = "INITIAL" # INITIAL, CLARIFYING, GENERATING
51
- if "initial_topic" not in st.session_state:
52
- st.session_state.initial_topic = ""
53
 
54
- # --- Sidebar for API Keys ---
55
- with st.sidebar:
56
- st.header("🔑 API Configuration")
57
- google_key = st.text_input("Google Gemini API Key", type="password", key="google_api_key")
58
- tavily_key = st.text_input("Tavily API Key", type="password", key="tavily_api_key")
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- if st.button("Initialize Agent"):
61
- if google_key and tavily_key:
62
- with st.spinner("Initializing Agent's RAG models... This may take a moment."):
63
- tools = AITools(api_keys={'google': google_key, 'tavily': tavily_key})
64
- st.session_state.agent = ResearchAgent(tools_instance=tools)
65
- st.success("Agent Initialized!")
66
- st.session_state.agent_state = "READY"
67
- else:
68
- st.error("Please provide all API keys.")
69
 
70
- # --- Main Application UI ---
71
- st.title("Mini DeepSearch Agent")
72
- st.markdown("<p>Your AI partner for in-depth research and analysis.</p>", unsafe_allow_html=True)
 
 
 
 
 
73
 
74
- # Display chat history
75
- for message in st.session_state.messages:
76
- with st.chat_message(message["role"]):
77
- st.markdown(message["content"])
 
 
 
 
 
 
 
 
 
78
 
79
- # Main chat input logic
80
- if st.session_state.agent_state != "INITIAL":
81
- if prompt := st.chat_input("Enter your research topic or answer the questions..."):
82
 
83
- # Add user message to chat
84
- st.session_state.messages.append({"role": "user", "content": prompt})
85
- with st.chat_message("user"):
86
- st.markdown(prompt)
 
87
 
88
- # Assistant's response
89
- with st.chat_message("assistant"):
90
- # Create containers for the two-column layout
91
- log_container = st.expander("Agent's Thought Process", expanded=True)
92
- report_container = st.container()
93
 
94
- full_report = ""
95
- log_messages = []
96
-
97
- # --- Logic for different agent states ---
98
- if st.session_state.agent_state == "READY":
99
- st.session_state.initial_topic = prompt
100
- st.session_state.agent_state = "CLARIFYING"
101
-
102
- # First call to the agent to get clarifying questions
103
- for msg_type, content in st.session_state.agent.run(prompt):
104
- if msg_type == 'log':
105
- log_messages.append(content)
106
- with log_container: st.write("\n".join(log_messages))
107
- elif msg_type == 'clarification':
108
- st.session_state.messages.append({"role": "assistant", "content": content})
109
- report_container.markdown(content)
110
-
111
- elif st.session_state.agent_state == "CLARIFYING":
112
- st.session_state.agent_state = "GENERATING"
113
 
114
- # Second call to the agent to continue the run with refinements
115
- report_generator = st.session_state.agent.continue_run(
116
- st.session_state.initial_topic,
117
- user_refinements=prompt
118
- )
119
 
120
- for msg_type, content in report_generator:
121
- if msg_type == 'log':
122
- log_messages.append(content)
123
- with log_container: st.write("\n".join(log_messages))
124
-
125
- elif msg_type == 'report_start':
126
- full_report = content
127
- with report_container: st.markdown(full_report)
128
 
129
- elif msg_type == 'report_stream':
130
- # This part handles the live-writing effect
131
- full_report = content # The content is the full report so far
132
- with report_container: st.markdown(full_report + "▌") # Add a blinking cursor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
- elif msg_type == 'report_chunk':
135
- # This replaces the streamed content with the final, formatted chunk
136
- full_report += content
137
- with report_container: st.markdown(full_report)
 
 
 
 
 
 
138
 
139
- # Final update after the loop
140
- with report_container: st.markdown(full_report)
141
- st.session_state.messages.append({"role": "assistant", "content": full_report})
142
- st.session_state.agent_state = "READY" # Reset for the next query
 
1
+ import os
2
+ import gradio as gr
 
3
  import google.generativeai as genai
4
  from tavily import TavilyClient
5
+ from sentence_transformers import SentenceTransformer, CrossEncoder
6
+
7
+ from research_agent.config import AgentConfig
8
+ from research_agent.agent import get_clarifying_questions, research_and_plan, write_report_stream
9
+
10
+ # --- CSS for a professional, ChatGPT-inspired look ---
11
+ CSS = """
12
+ body, .gradio-container { font-family: 'Inter', sans-serif; background-color: #343541; color: #ECECEC; }
13
+ .gradio-container { max-width: 800px !important; margin: auto !important; padding-top: 2rem !important;}
14
+ h1 { text-align: center; font-weight: 700; font-size: 2.5em; color: white; }
15
+ .sub-header { text-align: center; color: #C5C5D2; margin-bottom: 2rem; font-size: 1.1em; }
16
+ .accordion { background-color: #40414F; border: 1px solid #565869 !important; border-radius: 8px !important; }
17
+ .accordion .gr-button { background-color: #4B4C5A; color: white; }
18
+ #chatbot { box-shadow: none !important; border: none !important; background-color: transparent !important; }
19
+ .message-bubble { background: #40414F !important; border: 1px solid #565869 !important; color: #ECECEC !important;}
20
+ .message-bubble.user { background: #343541 !important; border: none !important; }
21
+ footer { display: none !important; }
22
+ .gr-box.gradio-container { padding: 0 !important; }
23
+ .gr-form { background-color: transparent !important; border: none !important; box-shadow: none !important; }
24
+ .gradio-container .gr-form .gr-button { display: none; } /* Hide the default submit button */
25
+ #chat-input-container { position: relative; }
26
+ #chat-input-container textarea { background-color: #40414F; color: white; border: 1px solid #565869 !important; }
27
+ #submit-button { position: absolute; right: 10px; top: 50%; transform: translateY(-50%); background: #2563EB; color: white; border-radius: 4px; padding: 4px 8px; }
28
+ """
29
+
30
# --- Model Initialization ---
# Shared agent configuration plus module-level handles for the LLM clients
# and local RAG models. All handles stay None until initialize_models()
# succeeds, so UI callbacks must not be reachable before initialization.
config = AgentConfig()
writer_model, planner_model, embedding_model, reranker, tavily_client = None, None, None, None, None
IS_PROCESSING = False  # Add a lock to prevent concurrent runs
34
 
35
def initialize_models(google_key, tavily_key):
    """Wire up the LLM clients and local RAG models from the given API keys.

    Stores the initialized handles in the module-level globals and resets
    the processing lock. Raises gr.Error when a key is missing or any
    client/model fails to load.
    """
    global writer_model, planner_model, embedding_model, reranker, tavily_client, IS_PROCESSING

    if not (google_key and tavily_key):
        raise gr.Error("API keys are required.")

    try:
        # Remote services first: Gemini SDK configuration and Tavily search.
        genai.configure(api_key=google_key)
        tavily_client = TavilyClient(api_key=tavily_key)

        # Both agent roles currently share the writer model name.
        # NOTE(review): confirm whether AgentConfig defines a separate
        # planner model that should be used here instead of WRITER_MODEL.
        writer_model = genai.GenerativeModel(config.WRITER_MODEL)
        planner_model = genai.GenerativeModel(config.WRITER_MODEL)

        # Local CPU-only models used by the RAG retrieval pipeline.
        embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
        reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device='cpu')
    except Exception as e:
        raise gr.Error(f"Failed to initialize models. Error: {str(e)}")

    IS_PROCESSING = False  # ensure the lock starts released after a (re)initialization
49
 
50
# --- Gradio Application Logic ---
# Single-page chat UI: an API-key accordion, a chatbot transcript, and one
# input box. Conversation phase is tracked in gr.State ("INITIAL" ->
# "CLARIFYING" -> "GENERATING" -> back to "INITIAL").
with gr.Blocks(css=CSS, theme=gr.themes.Base()) as app:
    gr.Markdown("<h1>Mini DeepSearch Agent</h1>")
    gr.Markdown("<p class='sub-header'>Your AI partner for in-depth research and analysis.</p>")

    # Per-session state: current phase of the agent, and the research topic
    # captured from the user's first message (needed again after clarification).
    agent_state = gr.State("INITIAL")
    initial_topic_state = gr.State("")

    with gr.Accordion("API & Settings", open=True, elem_classes="accordion") as settings_accordion:
        with gr.Row():
            google_api_key_input = gr.Textbox(label="Google API Key", type="password", placeholder="Enter Google AI API Key", scale=2)
            tavily_api_key_input = gr.Textbox(label="Tavily API Key", type="password", placeholder="Enter Tavily Search API Key", scale=2)
            init_button = gr.Button("Initialize Agent", scale=1)

    # Hidden until the agent is initialized (see handle_initialization).
    chatbot = gr.Chatbot(
        elem_id="chatbot",
        bubble_full_width=False,
        height=500,
        visible=False
    )

    with gr.Row(elem_id="chat-input-container"):
        chat_input = gr.Textbox(placeholder="What would you like to research?", interactive=False, visible=False, show_label=False, scale=8)
        submit_button = gr.Button("Submit", elem_id="submit-button", visible=False, scale=1)

    def handle_initialization(google_key, tavily_key):
        """Initialize models, then reveal the chat UI and collapse settings.

        Returns a component->update dict; initialize_models raises gr.Error
        (surfaced by Gradio) if keys are missing or loading fails.
        """
        initialize_models(google_key, tavily_key)
        return {
            chatbot: gr.update(visible=True, value=[(None, "Agent initialized. Please enter your research topic to begin.")]),
            chat_input: gr.update(interactive=True, visible=True),
            submit_button: gr.update(visible=True),
            settings_accordion: gr.update(open=False)
        }

    def chat_step_wrapper(user_input, history, current_agent_state, topic_state):
        """A wrapper to manage the processing lock."""
        global IS_PROCESSING
        if IS_PROCESSING:
            print("Ignoring duplicate request while processing.")
            # A function containing `yield` anywhere is a generator; this dead
            # branch exists only so the early-return path still returns an
            # (empty) generator instead of None, which Gradio would reject.
            if False: # This makes the function a generator
                yield
            return

        IS_PROCESSING = True
        try:
            # Yield all updates from the actual agent logic
            for update in chat_step(user_input, history, current_agent_state, topic_state):
                yield update
        except Exception as e:
            error_message = f"An error occurred: {str(e)}"
            history.append((None, error_message))
            # Reset the conversation to a usable state after any failure.
            yield history, "INITIAL", "", gr.update(interactive=True, placeholder="Let's try again. What's the topic?")
        finally:
            # Release the lock once the generator is exhausted or an error occurs
            IS_PROCESSING = False
            print("Processing finished. Lock released.")

    def chat_step(user_input, history, current_agent_state, topic_state):
        """Advance the research conversation one turn.

        Generator yielding (history, agent_state, topic_state, chat_input
        update) tuples, matching the event listeners' `outputs` list.
        INITIAL: ask clarifying questions and remember the topic.
        CLARIFYING: run the full research pipeline and stream the report.
        """
        history = history or []
        history.append((user_input, None))

        if current_agent_state == "INITIAL":
            # Show the user's message immediately, then fill in the questions.
            yield history, "CLARIFYING", user_input, gr.update(interactive=False, placeholder="Thinking...")
            questions = get_clarifying_questions(planner_model, user_input)
            history[-1] = (user_input, "To give you the best report, could you answer these questions for me?\n\n" + questions)
            yield history, "CLARIFYING", user_input, gr.update(interactive=True, placeholder="Provide your answers to the questions above...")

        elif current_agent_state == "CLARIFYING":
            thinking_message = "Got it. Generating your full research report. This will take a moment..."
            history[-1] = (user_input, thinking_message)
            yield history, "GENERATING", topic_state, gr.update(interactive=False, placeholder="Generating...")

            try:
                # topic_state holds the original topic; user_input holds the
                # clarification answers just provided.
                plan = research_and_plan(config, planner_model, tavily_client, topic_state, user_input)
                report_generator = write_report_stream(config, writer_model, tavily_client, embedding_model, reranker, plan)

                # Each update presumably carries the full report-so-far, since
                # it replaces (not appends to) the last bubble — TODO confirm
                # against write_report_stream.
                stream_content = ""
                for update in report_generator:
                    stream_content = update
                    history[-1] = (user_input, stream_content)
                    yield history, "GENERATING", topic_state, gr.update(interactive=False)

                # Done: reset state so the next message starts a new topic.
                yield history, "INITIAL", "", gr.update(interactive=True, placeholder="Research complete. What's the next topic?")

            except Exception as e:
                error_message = f"An error occurred: {str(e)}"
                history.append((None, error_message))
                yield history, "INITIAL", "", gr.update(interactive=True, placeholder="Let's try again. What's the topic?")

    # --- Event Listeners ---
    # This section is rewritten to prevent duplicate triggers.

    init_button.click(
        fn=handle_initialization,
        inputs=[google_api_key_input, tavily_api_key_input],
        outputs=[chatbot, chat_input, submit_button, settings_accordion]
    )

    # We define a single submission event and trigger it from both the button and the textbox.
    # It now calls the wrapper function to handle the processing lock.
    submit_event = submit_button.click(
        fn=chat_step_wrapper,
        inputs=[chat_input, chatbot, agent_state, initial_topic_state],
        outputs=[chatbot, agent_state, initial_topic_state, chat_input],
    ).then(
        # Clear the textbox after the run finishes; queue=False applies it
        # without waiting behind other queued events.
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=[chat_input],
        queue=False
    )

    chat_input.submit(
        fn=chat_step_wrapper,
        inputs=[chat_input, chatbot, agent_state, initial_topic_state],
        outputs=[chatbot, agent_state, initial_topic_state, chat_input],
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=[chat_input],
        queue=False
    )

app.launch(debug=True)