MusaR commited on
Commit
172912a
·
verified ·
1 Parent(s): 5330eae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -165
app.py CHANGED
@@ -1,166 +1,166 @@
1
- import os
2
- import gradio as gr
3
- import google.generativeai as genai
4
- from tavily import TavilyClient
5
- from sentence_transformers import SentenceTransformer, CrossEncoder
6
-
7
- from research_agent.config import AgentConfig
8
- from research_agent.agent import get_clarifying_questions, research_and_plan, write_report_stream
9
-
10
- # --- CSS for styling the Gradio app ---
11
- CSS = """
12
- body { font-family: 'Inter', sans-serif; background-color: #F0F2F6; }
13
- .gradio-container { max-width: 960px !important; margin: auto !important; }
14
- h1 { text-align: center; font-size: 2.5em; color: #1E3A8A; }
15
- .gr-button { background-color: #2563EB; color: white; }
16
- .gr-button:hover { background-color: #1E4ED8; }
17
- .status_box {
18
- background-color: #FFFFFF;
19
- border-radius: 8px;
20
- padding: 15px;
21
- border: 1px solid #E5E7EB;
22
- box-shadow: 0 2px 4px rgba(0,0,0,0.05);
23
- }
24
- .report_output {
25
- background-color: #FFFFFF;
26
- border-radius: 8px;
27
- padding: 20px;
28
- border: 1px solid #E5E7EB;
29
- box-shadow: 0 4px 8px rgba(0,0,0,0.05);
30
- }
31
- """
32
-
33
- # --- Global variables for models (to avoid reloading) ---
34
- writer_model = None
35
- planner_model = None
36
- embedding_model = None
37
- reranker = None
38
- tavily_client = None
39
- config = AgentConfig()
40
-
41
- def initialize_models(google_api_key, tavily_api_key):
42
- """Initializes all the necessary models and API clients."""
43
- global writer_model, planner_model, embedding_model, reranker, tavily_client
44
-
45
- if not google_api_key or not tavily_api_key:
46
- raise gr.Error("API keys are required. Please provide both Google and Tavily API keys.")
47
-
48
- try:
49
- genai.configure(api_key=google_api_key)
50
- tavily_client = TavilyClient(api_key=tavily_api_key)
51
-
52
- if writer_model is None:
53
- writer_model = genai.GenerativeModel(config.WRITER_MODEL)
54
- if planner_model is None:
55
- planner_model = genai.GenerativeModel(config.WRITER_MODEL)
56
- if embedding_model is None:
57
- embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
58
- if reranker is None:
59
- reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device='cpu')
60
-
61
- return "Models initialized successfully!"
62
- except Exception as e:
63
- raise gr.Error(f"Failed to initialize models. Please check your API keys. Error: {str(e)}")
64
-
65
- def start_research_phase(topic, google_key, tavily_key):
66
- """Phase 1: Get user topic and return clarifying questions."""
67
- initialize_models(google_key, tavily_key)
68
-
69
- if not topic:
70
- raise gr.Error("Research topic cannot be empty.")
71
-
72
- questions = get_clarifying_questions(planner_model, topic)
73
-
74
- # Show the next stage of the UI
75
- return {
76
- clarification_ui: gr.update(visible=True),
77
- clarification_questions_display: gr.update(value=questions),
78
- initial_ui: gr.update(visible=False)
79
- }
80
-
81
- def generate_report_phase(topic, answers, google_key, tavily_key):
82
- """Phase 2: Take answers and generate the full report, streaming progress."""
83
- initialize_models(google_key, tavily_key)
84
-
85
- status_updates = "### Agent Status\n"
86
- yield {
87
- status_box: gr.update(value=status_updates + "-> Planning research...\n"),
88
- final_report: gr.update(value=None)
89
- }
90
-
91
- try:
92
- plan = research_and_plan(config, planner_model, tavily_client, topic, answers)
93
- except Exception as e:
94
- raise gr.Error(f"Failed during planning phase: {e}")
95
-
96
- status_updates += f"**Research Plan:**\n- **Topic:** {plan['detailed_topic']}\n- **Sections:** {[s.title for s in plan['sections']]}\n\n---\n"
97
- yield { status_box: gr.update(value=status_updates) }
98
-
99
- report_generator = write_report_stream(config, writer_model, tavily_client, embedding_model, reranker, plan)
100
-
101
- final_report_md = ""
102
- for update in report_generator:
103
- if isinstance(update, str):
104
- final_report_md = update
105
- status_updates += update
106
- yield { status_box: gr.update(value=status_updates) }
107
-
108
- yield { final_report: gr.update(value=final_report_md) }
109
-
110
- # --- Build the Gradio Interface ---
111
- with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as app:
112
-
113
- gr.Markdown("# Mini DeepSearch Agent")
114
- gr.Markdown("This agent performs in-depth research on a given topic, using AI to plan, search, and write a comprehensive report.")
115
-
116
- # State to hold the original topic
117
- topic_state = gr.State()
118
-
119
- # --- UI Stage 1: Initial Query ---
120
- with gr.Box(visible=True) as initial_ui:
121
- with gr.Row():
122
- google_api_key_input = gr.Textbox(label="Google API Key", type="password", placeholder="Enter your Google AI API Key")
123
- tavily_api_key_input = gr.Textbox(label="Tavily API Key", type="password", placeholder="Enter your Tavily Search API Key")
124
-
125
- topic_input = gr.Textbox(label="Research Topic", placeholder="e.g., The future of renewable energy")
126
- start_button = gr.Button("Start Research", variant="primary")
127
-
128
- # --- UI Stage 2: Clarification ---
129
- with gr.Box(visible=False) as clarification_ui:
130
- gr.Markdown("### To give you the most relevant report, could you please clarify:")
131
- clarification_questions_display = gr.Markdown(elem_classes="status_box")
132
- clarification_answers_input = gr.Textbox(label="Your Answers", placeholder="Provide your answers to the questions above to tailor the research...")
133
- generate_report_button = gr.Button("Generate Full Report", variant="primary")
134
-
135
- # --- UI Stage 3: Output ---
136
- with gr.Column():
137
- status_box = gr.Markdown(elem_classes="status_box", label="Agent Thought Process", visible=False)
138
- final_report = gr.Markdown(elem_classes="report_output", label="Final Research Report", visible=False)
139
-
140
- # --- Event Handlers ---
141
- def show_outputs():
142
- return {
143
- status_box: gr.update(visible=True),
144
- final_report: gr.update(visible=True)
145
- }
146
-
147
- start_button.click(
148
- fn=start_research_phase,
149
- inputs=[topic_input, google_api_key_input, tavily_api_key_input],
150
- outputs=[initial_ui, clarification_ui, clarification_questions_display]
151
- ).then(
152
- fn=lambda topic: topic,
153
- inputs=[topic_input],
154
- outputs=[topic_state] # Save the topic for the next step
155
- )
156
-
157
- generate_report_button.click(
158
- fn=show_outputs,
159
- outputs=[status_box, final_report]
160
- ).then(
161
- fn=generate_report_phase,
162
- inputs=[topic_state, clarification_answers_input, google_api_key_input, tavily_api_key_input],
163
- outputs=[status_box, final_report]
164
- )
165
-
166
  app.launch(debug=True)
 
1
+ import os
2
+ import gradio as gr
3
+ import google.generativeai as genai
4
+ from tavily import TavilyClient
5
+ from sentence_transformers import SentenceTransformer, CrossEncoder
6
+
7
+ from research_agent.config import AgentConfig
8
+ from research_agent.agent import get_clarifying_questions, research_and_plan, write_report_stream
9
+
10
+ # --- CSS for styling the Gradio app ---
11
+ CSS = """
12
+ body { font-family: 'Inter', sans-serif; background-color: #F0F2F6; }
13
+ .gradio-container { max-width: 960px !important; margin: auto !important; }
14
+ h1 { text-align: center; font-size: 2.5em; color: #1E3A8A; }
15
+ .gr-button { background-color: #2563EB; color: white; }
16
+ .gr-button:hover { background-color: #1E4ED8; }
17
+ .status_box {
18
+ background-color: #FFFFFF;
19
+ border-radius: 8px;
20
+ padding: 15px;
21
+ border: 1px solid #E5E7EB;
22
+ box-shadow: 0 2px 4px rgba(0,0,0,0.05);
23
+ }
24
+ .report_output {
25
+ background-color: #FFFFFF;
26
+ border-radius: 8px;
27
+ padding: 20px;
28
+ border: 1px solid #E5E7EB;
29
+ box-shadow: 0 4px 8px rgba(0,0,0,0.05);
30
+ }
31
+ """
32
+
33
+ # --- Global variables for models (to avoid reloading) ---
34
+ writer_model = None
35
+ planner_model = None
36
+ embedding_model = None
37
+ reranker = None
38
+ tavily_client = None
39
+ config = AgentConfig()
40
+
41
+ def initialize_models(google_api_key, tavily_api_key):
42
+ """Initializes all the necessary models and API clients."""
43
+ global writer_model, planner_model, embedding_model, reranker, tavily_client
44
+
45
+ if not google_api_key or not tavily_api_key:
46
+ raise gr.Error("API keys are required. Please provide both Google and Tavily API keys.")
47
+
48
+ try:
49
+ genai.configure(api_key=google_api_key)
50
+ tavily_client = TavilyClient(api_key=tavily_api_key)
51
+
52
+ if writer_model is None:
53
+ writer_model = genai.GenerativeModel(config.WRITER_MODEL)
54
+ if planner_model is None:
55
+ planner_model = genai.GenerativeModel(config.WRITER_MODEL)
56
+ if embedding_model is None:
57
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
58
+ if reranker is None:
59
+ reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device='cpu')
60
+
61
+ return "Models initialized successfully!"
62
+ except Exception as e:
63
+ raise gr.Error(f"Failed to initialize models. Please check your API keys. Error: {str(e)}")
64
+
65
+ def start_research_phase(topic, google_key, tavily_key):
66
+ """Phase 1: Get user topic and return clarifying questions."""
67
+ initialize_models(google_key, tavily_key)
68
+
69
+ if not topic:
70
+ raise gr.Error("Research topic cannot be empty.")
71
+
72
+ questions = get_clarifying_questions(planner_model, topic)
73
+
74
+ # Show the next stage of the UI
75
+ return {
76
+ clarification_ui: gr.update(visible=True),
77
+ clarification_questions_display: gr.update(value=questions),
78
+ initial_ui: gr.update(visible=False)
79
+ }
80
+
81
+ def generate_report_phase(topic, answers, google_key, tavily_key):
82
+ """Phase 2: Take answers and generate the full report, streaming progress."""
83
+ initialize_models(google_key, tavily_key)
84
+
85
+ status_updates = "### Agent Status\n"
86
+ yield {
87
+ status_box: gr.update(value=status_updates + "-> Planning research...\n"),
88
+ final_report: gr.update(value=None)
89
+ }
90
+
91
+ try:
92
+ plan = research_and_plan(config, planner_model, tavily_client, topic, answers)
93
+ except Exception as e:
94
+ raise gr.Error(f"Failed during planning phase: {e}")
95
+
96
+ status_updates += f"**Research Plan:**\n- **Topic:** {plan['detailed_topic']}\n- **Sections:** {[s.title for s in plan['sections']]}\n\n---\n"
97
+ yield { status_box: gr.update(value=status_updates) }
98
+
99
+ report_generator = write_report_stream(config, writer_model, tavily_client, embedding_model, reranker, plan)
100
+
101
+ final_report_md = ""
102
+ for update in report_generator:
103
+ if isinstance(update, str):
104
+ final_report_md = update
105
+ status_updates += update
106
+ yield { status_box: gr.update(value=status_updates) }
107
+
108
+ yield { final_report: gr.update(value=final_report_md) }
109
+
110
+ # --- Build the Gradio Interface ---
111
+ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as app:
112
+
113
+ gr.Markdown("# Mini DeepSearch Agent")
114
+ gr.Markdown("This agent performs in-depth research on a given topic, using AI to plan, search, and write a comprehensive report.")
115
+
116
+ # State to hold the original topic
117
+ topic_state = gr.State()
118
+
119
+ # --- UI Stage 1: Initial Query ---
120
+ with gr.Column(visible=True) as initial_ui:
121
+ with gr.Row():
122
+ google_api_key_input = gr.Textbox(label="Google API Key", type="password", placeholder="Enter your Google AI API Key")
123
+ tavily_api_key_input = gr.Textbox(label="Tavily API Key", type="password", placeholder="Enter your Tavily Search API Key")
124
+
125
+ topic_input = gr.Textbox(label="Research Topic", placeholder="e.g., The future of renewable energy")
126
+ start_button = gr.Button("Start Research", variant="primary")
127
+
128
+ # --- UI Stage 2: Clarification ---
129
+ with gr.Column(visible=False) as clarification_ui:
130
+ gr.Markdown("### To give you the most relevant report, could you please clarify:")
131
+ clarification_questions_display = gr.Markdown(elem_classes="status_box")
132
+ clarification_answers_input = gr.Textbox(label="Your Answers", placeholder="Provide your answers to the questions above to tailor the research...")
133
+ generate_report_button = gr.Button("Generate Full Report", variant="primary")
134
+
135
+ # --- UI Stage 3: Output ---
136
+ with gr.Column():
137
+ status_box = gr.Markdown(elem_classes="status_box", label="Agent Thought Process", visible=False)
138
+ final_report = gr.Markdown(elem_classes="report_output", label="Final Research Report", visible=False)
139
+
140
+ # --- Event Handlers ---
141
+ def show_outputs():
142
+ return {
143
+ status_box: gr.update(visible=True),
144
+ final_report: gr.update(visible=True)
145
+ }
146
+
147
+ start_button.click(
148
+ fn=start_research_phase,
149
+ inputs=[topic_input, google_api_key_input, tavily_api_key_input],
150
+ outputs=[initial_ui, clarification_ui, clarification_questions_display]
151
+ ).then(
152
+ fn=lambda topic: topic,
153
+ inputs=[topic_input],
154
+ outputs=[topic_state] # Save the topic for the next step
155
+ )
156
+
157
+ generate_report_button.click(
158
+ fn=show_outputs,
159
+ outputs=[status_box, final_report]
160
+ ).then(
161
+ fn=generate_report_phase,
162
+ inputs=[topic_state, clarification_answers_input, google_api_key_input, tavily_api_key_input],
163
+ outputs=[status_box, final_report]
164
+ )
165
+
166
  app.launch(debug=True)