Upload 5 files
Browse files- app.py +348 -0
- deep_research.py +185 -0
- main_agent.py +208 -0
- requirements.txt +14 -0
- table_maker.py +106 -0
app.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from main_agent import Main_agent
|
| 4 |
+
import asyncio
|
| 5 |
+
import tempfile
|
| 6 |
+
import os
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from io import BytesIO
|
| 9 |
+
import requests
|
| 10 |
+
from spire.doc import Document, FileFormat
|
| 11 |
+
import nest_asyncio
|
| 12 |
+
import platform
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
|
| 15 |
+
# Apply nest_asyncio to allow nested event loops
# (Streamlit runs the script inside its own loop; without this, nested
# run_until_complete calls in process_query_sync would raise).
nest_asyncio.apply()

# Create event loop
# On Windows the default Proactor event loop is incompatible with some
# async libraries, so switch to the selector-based policy there.
if platform.system() == 'Windows':
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
| 21 |
+
|
| 22 |
+
# Canned deep-research output used to exercise the rendering helpers
# without hitting any external API.
test_data = {
    'title': 'Sunday App and Stripe: A Payment Strategy for Restaurants',
    'image_url': 'https://www.menutiger.com/_next/image?url=http%3A%2F%2Fcms.menutiger.com%2Fwp-content%2Fuploads%2F2023%2F07%2Fseamless-stripe-payment-integration-for-restaurants-2.jpg&w=2048&q=75',
    'paragraphs': [
        {'title': 'Introduction',
         'content': 'Sunday is a payment solution designed to streamline the restaurant experience. Co-founded in April 2021 by Christine de Wendel, Victor Lugger, and Tigrane Seydoux, it aims to address inefficiencies in the restaurant industry by simplifying payments through QR code technology, allowing customers to quickly access their bill, split it, add a tip, and pay in seconds. Integrated with existing POS systems and partnering with payment providers like Stripe and Checkout.com, Sunday also offers features such as digital menus with customizable options, order and pay capabilities, and tools to enhance online reputation and customer engagement for restaurants.'},
        {'title': "Sunday's Payment Strategy",
         'content': "Sunday's payment strategy focuses on streamlining the payment process in restaurants through its QR code-based system, enabling customers to pay quickly and efficiently. This approach not only simplifies payments but also reduces processing fees for restaurants through partnerships with companies like Stripe and Checkout.com. The company's payment solution is designed to integrate seamlessly with existing POS systems, offering additional features such as digital menus, order and pay capabilities, and tools to improve online reputation and customer engagement."},
        {'title': 'Partnership with Stripe',
         'content': 'On Wednesday, Sunday announced its new Order & Pay capabilities and its partnership with Stripe, which has allowed it to reduce processing fees for restaurants by an average of 0.5%. Through Stripe, Sunday processes payments via a QR code at the table, which allows customers to pay, split the bill, and add a tip quickly. Sunday is also negotiating with payment providers for better rates.'},
        {'title': 'Stripe Integration and Fee Reduction',
         'content': 'Sunday leverages Stripe to process payments, utilizing a QR code system for customer convenience. This partnership has enabled Sunday to reduce processing fees for restaurants by an average of 0.5%.'},
        {'title': 'Benefits for Restaurants',
         'content': 'Sunday offers numerous benefits for restaurants by leveraging its technology and partnerships, including Stripe. Restaurants using Sunday can experience increased efficiency through streamlined payment processes and order and pay capabilities. Sunday also helps restaurants collect valuable data and analytics, including customer insights and business analytics, which can be used to improve operations and enhance customer experience. Additionally, the integration with Stripe reduces processing fees, benefiting restaurant finances.'},
        {'title': 'Conclusion',
         'content': 'In conclusion, Sunday has significantly impacted restaurant payments by streamlining the payment process and reducing processing fees through its partnership with Stripe. This collaboration has enabled restaurants to enhance efficiency, improve customer experiences, and gain valuable data insights, positioning Sunday as a key player in revolutionizing the restaurant industry.'}],
    'table': {'data': [['App Functionality',
                        'Access bill and pay in two clicks, browse the full menu.'],
                       ['Payment Processing', 'Payments processed by Stripe.'],
                       ['Stripe Agreement Benefit',
                        'Reduced processing fees for restaurants by an average of 0.5%.'],
                       ['Restaurant Benefits',
                        'More tips, more reviews, and more data for restaurants. Increased order and average basket size.'],
                       ['Point of Sale Integration',
                        'Works seamlessly with existing POS systems.'],
                       ['Payment Method', 'Guests scan a QR code to pay.'],
                       ['Additional Features',
                        'Digital menu with customizable features (scheduling, tags, extras). Options to split the bill and tip.']],
              'columns': ['Aspect', 'Details']},
    'references': 'https://sundayapp.com/2024-recap-how-sunday-revolutionized-restaurant-payments/, https://sundayapp.com/blog/, https://www.checkout.com/case-studies/sunday-brings-the-ease-of-online-payments-to-the-offline-dining-experience'}


@dataclass
class Research_paper_test:
    """Lightweight container for a canned research-paper dict."""
    test_data: dict


# BUG FIX: the original did `test_paper = Research_paper_test` (binding the
# class itself) and then set a *class* attribute; create a proper instance.
test_paper = Research_paper_test(test_data=test_data)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Initialize the agent
# NOTE(review): a new Main_agent is built on every Streamlit rerun; its
# state is re-hydrated from session_state below — confirm this is intended.
agent = Main_agent()

# Initialize session state
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

if 'table_data' not in st.session_state:
    st.session_state.table_data = {}
if 'quick_search_results' not in st.session_state:
    st.session_state.quick_search_results = []
if 'research_paper' not in st.session_state:
    st.session_state.research_paper = {}

if 'agent_memory' not in st.session_state:
    st.session_state.agent_memory = []

# One long-lived event loop per browser session, installed as the thread's
# current loop so process_query_sync can reuse it across reruns.
if 'loop' not in st.session_state:
    st.session_state.loop = asyncio.new_event_loop()
    asyncio.set_event_loop(st.session_state.loop)

# Re-attach session-persisted artifacts to the freshly created agent.
agent.deps.deep_search_results = st.session_state.research_paper
agent.deps.quick_search_results = st.session_state.quick_search_results
agent.deps.table_data = st.session_state.table_data
agent.memory.messages = st.session_state.agent_memory
|
| 84 |
+
|
| 85 |
+
def paper_to_markdown(paper: dict) -> str:
    """Convert a paper dictionary into a markdown string.

    Args:
        paper: dict with keys 'title', optional 'image_url', 'paragraphs'
            (list of {'title', 'content'} dicts), optional 'table'
            (either {'data': rows, 'columns': names} or any mapping pandas
            can turn into a DataFrame) and 'references'.

    Returns:
        The assembled markdown document as one string.
    """
    markdown = [f"# {paper.get('title')}", "\n\n"]

    # Header image (the original source had this line garbled; this is the
    # standard markdown image syntax).
    if paper.get('image_url'):
        markdown.append(f"![image]({paper.get('image_url')})\n")
        markdown.append("\n\n")

    # Body sections.
    for section in paper.get('paragraphs', []):
        markdown.append(f"## {section.get('title')}\n")
        if section.get('content'):
            markdown.append(section.get('content'))
            markdown.append("\n")
        markdown.append("\n\n")

    # Optional table: prefer the {'data', 'columns'} layout, otherwise let
    # pandas interpret the mapping directly.
    if paper.get('table'):
        table = paper.get('table')
        try:
            markdown.append(pd.DataFrame(data=table['data'],
                                         columns=table['columns']).to_markdown())
        except (KeyError, TypeError, ValueError):
            markdown.append(pd.DataFrame(table).to_markdown())
        markdown.append("\n\n")

    # References footer.
    markdown.append("## References\n")
    markdown.append(str(paper.get('references')))

    return "".join(markdown)
|
| 117 |
+
|
| 118 |
+
def download_as_docx(paper: dict) -> str:
    """Render *paper* to markdown, convert it to DOCX with Spire.Doc, and
    return the output path ("output.docx").

    A temporary .md file (and a temporary .png when the paper has an
    'image_url') is created for the conversion and removed before
    returning; the caller deletes the returned DOCX when done.
    """
    def temp_image_file(url: str) -> str:
        # Download the header image to a local file so the converter can
        # embed it by path.
        response = requests.get(url, timeout=30)
        image = Image.open(BytesIO(response.content))
        # Create the temp file closed (delete=False) and let PIL reopen it
        # by name — saving while the handle is still open fails on Windows.
        tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
        tmp.close()
        image.save(tmp.name)
        return tmp.name

    image_path = temp_image_file(paper.get('image_url')) if paper.get('image_url') else None

    # Prepare markdown (same layout as paper_to_markdown, but the image is
    # referenced by local path instead of URL).
    markdown = [f"# {paper.get('title')}", "\n\n"]
    if image_path:
        # The original source had this image line garbled; standard syntax.
        markdown.append(f"![image]({image_path})\n")
        markdown.append("\n\n")

    # Add body sections
    for section in paper.get('paragraphs', []):
        markdown.append(f"## {section.get('title')}\n")
        if section.get('content'):
            markdown.append(section.get('content'))
            markdown.append("\n")
        markdown.append("\n\n")

    if paper.get('table'):
        table = paper.get('table')
        try:
            markdown.append(pd.DataFrame(data=table['data'],
                                         columns=table['columns']).to_markdown())
        except (KeyError, TypeError, ValueError):
            markdown.append(pd.DataFrame(table).to_markdown())
        markdown.append("\n\n")

    # Add references
    markdown.append("## References\n")
    markdown.append(str(paper.get('references')))

    markdown_str = "".join(markdown)

    # Write the markdown to a temporary file for Spire.Doc to load.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False) as temp_file:
        temp_file.write(markdown_str)
        temp_file_path = temp_file.name

    # Convert markdown -> DOCX.
    document = Document()
    document.LoadFromFile(temp_file_path)
    output_path = "output.docx"
    document.SaveToFile(output_path, FileFormat.Docx)

    # Clean up temporary files.
    os.unlink(temp_file_path)
    if image_path:
        os.unlink(image_path)

    return output_path
|
| 183 |
+
|
| 184 |
+
async def process_query(query: str):
    """Run one user query through the agent and mirror the results into
    Streamlit session state (chat history, memory, and any artifacts)."""
    answer = await agent.chat(query)

    # Record both sides of the exchange.
    st.session_state.chat_history.append({"role": "user", "content": query})
    st.session_state.chat_history.append({"role": "assistant", "content": str(answer)})
    st.session_state.agent_memory = agent.memory.messages

    # Pull any artifacts the agent produced back into the session.
    if agent.deps.deep_search_results:
        st.session_state.research_paper = agent.deps.deep_search_results
        st.session_state.table_data = agent.deps.table_data
    if agent.deps.table_data:
        st.session_state.table_data = agent.deps.table_data
    if agent.deps.quick_search_results:
        st.session_state.quick_search_results = agent.deps.quick_search_results
|
| 199 |
+
|
| 200 |
+
def process_query_sync(query: str):
    """Synchronous wrapper around the async process_query().

    Uses the session's persistent loop when it is idle; otherwise runs the
    coroutine on a throwaway loop.

    Returns:
        Whatever process_query returns (currently None).
    """
    if not st.session_state.loop.is_running():
        return st.session_state.loop.run_until_complete(process_query(query))
    # The session loop is busy: run on a fresh loop instead.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(process_query(query))
    finally:
        loop.close()
        # FIX: the original left the just-closed loop installed as the
        # thread's current loop; restore the long-lived session loop so
        # subsequent calls do not pick up a closed loop.
        asyncio.set_event_loop(st.session_state.loop)
|
| 212 |
+
|
| 213 |
+
def reset_chat():
    """Reset the agent and clear every session-scoped artifact."""
    agent.reset()
    # Wipe each piece of session state back to its empty default.
    for key, empty in (('chat_history', []),
                       ('research_paper', {}),
                       ('table_data', {}),
                       ('quick_search_results', []),
                       ('agent_memory', [])):
        st.session_state[key] = empty
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# Streamlit UI
st.title("Research Assistant Chatbot")

# Create a container for the chat interface
chat_container = st.container()

# Create two columns - one for chat and one for controls

# Chat interface with custom styling
st.markdown("""
<style>
.stChatMessage {
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 10px;
}
.user-message {
    background-color: #e6f3ff;
    text-align: right;
}
.assistant-message {
    background-color: #f0f0f0;
    text-align: left;
}
</style>
""", unsafe_allow_html=True)

# Chat messages in a scrollable container
with st.container(height=300):
    for message in st.session_state.chat_history:
        with st.chat_message(message["role"], avatar="🧑" if message["role"] == "user" else "🤖"):
            st.write(message["content"])

# Input area at the bottom
with st.container():
    query = st.text_input("Ask me anything...", key="query_input", placeholder="Type your question here...")
    col3, col4 = st.columns([6, 1])
    with col3:
        if st.button("Submit", use_container_width=True):
            if query:
                with st.spinner('Processing your request...'):
                    process_query_sync(query)
                st.rerun()


# Control buttons in the sidebar
with st.sidebar:
    st.markdown("### Helpful Tips")
    st.markdown("---")
    st.markdown("""reset chat if you want to start over or if the agent is not responding properly""")
    if st.button("Reset Chat", use_container_width=True):
        reset_chat()

    # Add some helpful tips
    st.markdown("---")
    st.markdown("### Tips")
    st.markdown("""
    - ask the agent about its capabilities and what it can do
    - the table is editable, you can edit in the files section or in the chat
    - the research paper is also editable, you can edit in the chat by asking the agent to edit the research paper
    - the agent can also add a table to the research paper, you can see the table in the files section or in the chat
    """)

    st.markdown('### Files ')

    # Research Paper popover
    # NOTE(review): indentation was lost in this source; these popovers are
    # assumed to live inside the sidebar — confirm against the running app.
    if st.session_state.research_paper :
        with st.popover("📄 Research Paper",use_container_width=True):
            st.markdown("""
            <style>
            .paper-preview {
                border: 1px solid #e0e0e0;
                border-radius: 5px;
                padding: 10px;
                margin-bottom: 10px;
            }
            </style>
            """, unsafe_allow_html=True)

            # Paper content
            st.markdown(paper_to_markdown(st.session_state.research_paper))

            # Download button
            # NOTE(review): a download_button created inside a button branch
            # disappears on the next rerun — known Streamlit pitfall; verify.
            if st.button("📥 Save as DOCX", use_container_width=True):
                with st.spinner('Preparing document...'):
                    output_path = download_as_docx(st.session_state.research_paper)
                    with open(output_path, "rb") as file:
                        st.download_button(
                            label="Download",
                            data=file,
                            file_name="research_paper.docx",
                            mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
                        )
                    os.unlink(output_path)

    # Table popover
    if st.session_state.table_data:
        with st.popover("📊 Editable Table",use_container_width=True):
            try:
                # Prefer the {'data', 'columns'} layout, fall back to letting
                # pandas interpret the mapping directly.
                try:
                    df = pd.DataFrame(data=st.session_state.table_data['data'],
                                      columns=st.session_state.table_data['columns'])
                except:
                    df = pd.DataFrame(st.session_state.table_data)

                edited_df = st.data_editor(df)
                if st.session_state.research_paper:
                    if st.button("💾 add table to research paper", use_container_width=True):
                        # Convert DataFrame back to the correct structure
                        table_dict = {
                            'data': edited_df.values.tolist(),
                            'columns': edited_df.columns.tolist()
                        }
                        st.session_state.research_paper['table'] = table_dict
                        st.session_state.table_data = table_dict
                        st.success("Table updated successfully!")
                        st.rerun()
            except Exception as e:
                st.error(f"Error loading table data: {str(e)}")
|
deep_research.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic_graph import BaseNode, End, GraphRunContext, Graph
|
| 2 |
+
from pydantic_ai import Agent
|
| 3 |
+
from pydantic_ai.common_tools.tavily import tavily_search_tool
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from pydantic import Field, BaseModel
|
| 6 |
+
from typing import List, Dict, Optional, Any
|
| 7 |
+
from pydantic_ai.models.gemini import GeminiModel
|
| 8 |
+
from pydantic_ai.providers.google_gla import GoogleGLAProvider
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
import os
|
| 11 |
+
from tavily import TavilyClient
|
| 12 |
+
from IPython.display import Image, display
|
| 13 |
+
import requests
|
| 14 |
+
import time
|
| 15 |
+
|
| 16 |
+
# Load API credentials from the local .env file.
load_dotenv()
google_api_key = os.getenv('google_api_key')  # Gemini / Google API key
tavily_key = os.getenv('tavily_key')          # Tavily web-search API key
tavily_client = TavilyClient(api_key=tavily_key)
# Shared Gemini model used by every agent in this module.
llm = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=google_api_key))
pse = os.getenv('pse')  # Google Programmable Search Engine id (image search)
|
| 22 |
+
|
| 23 |
+
@dataclass
class State:
    # Mutable state threaded through the research graph nodes.
    query: str                 # the user's research request
    preliminary_research: str  # output of the first broad web search
    research_plan: Dict        # NOTE(review): holds a Research_plan model at runtime, not a dict — confirm
    research_results: Dict     # NOTE(review): holds a Research_results model at runtime — confirm
    validation: str            # not used by the visible graph nodes
    final: Dict                # the assembled paper dict returned by PaperGen_node
|
| 31 |
+
class paragraph_content(BaseModel):
    # A fully written paragraph produced by paragraph_gen_agent.
    title: str = Field(description='the title of the paragraph')
    content: str = Field(description='the content of the paragraph')


class paragraph(BaseModel):
    # A planned (not yet written) paragraph in the paper layout.
    title: str = Field(description='the title of the paragraph')
    should_include: str = Field(description='a description of what the paragraph should include')


class Paper_layout(BaseModel):
    # Skeleton of the paper produced by paper_layout_agent.
    title: str = Field(description='the title of the paper')
    paragraphs: List[paragraph] = Field(description='the list of paragraphs of the paper')
|
| 41 |
+
|
| 42 |
+
# Produces the paper skeleton (title + paragraph headings only).
paper_layout_agent = Agent(llm, result_type=Paper_layout, system_prompt="generate a paper layout based on the query, preliminary_search, search_results,include a Title for the paper, for the paragraphs only include the title, no content, no image, no table, start with introduction and end with conclusion")
# Writes one paragraph at a time, given the research and what was already written.
paragraph_gen_agent = Agent(llm, result_type=paragraph_content, system_prompt="generate a paragraph synthesizing the research_results based on the title,what the paragraph should include, and what has already been written to avoid repetition")
|
| 44 |
+
class PaperGen_node(BaseNode[State]):
    """Final graph node: lay out the paper, write each paragraph with the
    LLM, then assemble the paper dict and end the graph run."""

    async def run(self, ctx: GraphRunContext[State]) -> End:
        import asyncio  # local import: this module does not import asyncio at top level

        prompt = (f'query:{ctx.state.query}, preliminary_search:{ctx.state.preliminary_research},search_results:{ctx.state.research_results.research_results}')
        result = await paper_layout_agent.run(prompt)

        paragraphs = []
        for planned in result.data.paragraphs:
            # FIX: time.sleep(2) blocked the whole event loop inside this
            # coroutine; asyncio.sleep rate-limits the LLM calls without
            # freezing other tasks.
            await asyncio.sleep(2)
            paragraph_data = await paragraph_gen_agent.run(
                f'title:{planned.title}, should_include:{planned.should_include}, research_results:{ctx.state.research_results.research_results}, already_written:{paragraphs}')
            paragraphs.append(paragraph_data.data.model_dump())

        paper = {
            'title': result.data.title,
            'image_url': ctx.state.research_results.image_url if ctx.state.research_results.image_url else None,
            'paragraphs': paragraphs,
            'table': ctx.state.research_results.table if ctx.state.research_results.table else None,
            'references': ctx.state.research_results.references if ctx.state.research_results.references else None,
        }

        ctx.state.final = paper
        return End(ctx.state.final)
|
| 63 |
+
|
| 64 |
+
def google_image_search(query: str):
    """Search for images using the Google Custom Search API.

    Args:
        query: the image search query.

    Returns:
        The URL of the first image hit, or None when the response has no
        'items' (no matches, bad credentials, or quota exhausted).
    """
    # Define the API endpoint for Google Custom Search
    url = "https://www.googleapis.com/customsearch/v1"

    params = {
        "q": query,
        "cx": pse,
        "key": google_api_key,
        "searchType": "image",  # Search for images
        "num": 1                # Number of results to fetch
    }

    # FIX: add a timeout so a stalled API call cannot hang the graph run.
    response = requests.get(url, params=params, timeout=30)
    data = response.json()

    # Check if the response contains image results
    if 'items' in data:
        # Extract the first image result
        return data['items'][0]['link']
    # FIX: explicit None instead of the original implicit fall-through.
    return None
|
| 89 |
+
|
| 90 |
+
class Table_row(BaseModel):
    # One row of the generated table.
    data: List[str] = Field(description='the data of the row')


class Table(BaseModel):
    # A table produced by table_agent.
    rows: List[Table_row] = Field(description='the rows of the table')
    columns: List[str] = Field(description='the columns of the table')


class Research_results(BaseModel):
    """Accumulated output of the research phase.

    FIX: the original used Field(default_factory=None), which is a no-op
    (default_factory=None means "no factory"), leaving every field required
    despite the clearly optional intent; use genuine None defaults.
    """
    research_results: Optional[List[str]] = Field(default=None, description='the research results')
    image_url: Optional[str] = Field(default=None, description='the image url if needed else return None')
    table: Optional[dict] = Field(default=None, description='the table dataframe in a dictionary format')
    references: Optional[str] = Field(default=None, description='the references (urls) of the research_results')


# Turns raw research text into a structured Table.
table_agent = Agent(llm, result_type=Table, system_prompt="generate a detailed table in dictionary format based on the research and the query")
| 104 |
+
class Research_node(BaseNode[State]):
    """Execute the research plan: run each planned web search, optionally
    fetch an image and build a table, then hand off to PaperGen_node."""

    async def run(self, ctx: GraphRunContext[State]) -> PaperGen_node:
        research_results = Research_results(research_results=[], image_url='', table={}, references='')

        # FIX: the original reset its reference accumulator (`data=[]`)
        # inside the outer loop, so only the LAST query's references
        # survived; it also reused the loop variable `i` for both loops.
        references = []
        for planned in ctx.state.research_plan.search_queries:
            response = tavily_client.search(planned.search_query)
            for item in response.get('results') or []:
                # Keep only reasonably relevant hits.
                if item.get('score') > 0.50:
                    references.append(item.get('url'))
                    research_results.research_results.append(item.get('content'))

        # Deduplicate content and references.
        research_results.research_results = list(set(research_results.research_results))
        research_results.references = ', '.join(set(references))
        ctx.state.research_results = research_results

        if ctx.state.research_plan.image_search_query:
            image_url = google_image_search(ctx.state.research_plan.image_search_query)
            ctx.state.research_results.image_url = image_url

        if ctx.state.research_plan.table:
            result = await table_agent.run(f'research_results:{ctx.state.research_results.research_results},query:{ctx.state.query}')
            ctx.state.research_results.table = {'data': [row.data for row in result.data.rows], 'columns': result.data.columns}

        return PaperGen_node()
|
| 131 |
+
|
| 132 |
+
class search_query(BaseModel):
    # A single planned web search.
    search_query: str = Field(description='the detailed web search query for the research')


class Research_plan(BaseModel):
    """The structured plan produced by research_plan_agent.

    FIX: Field(default_factory=None) is a no-op and left the two Optional
    fields required; use real None defaults so they are truly optional.
    """
    search_queries: List[search_query] = Field(description='the detailed web search queries for the research')
    table: Optional[str] = Field(default=None, description='if a table is needed, return yes else return None')
    image_search_query: Optional[str] = Field(default=None, description='if image is needed, generate a image search query, optional')


# Breaks the research task down into concrete search queries (plus optional
# table / image requests).
research_plan_agent = Agent(llm, result_type=Research_plan, system_prompt='generate a detailed research plan breaking down the research into smaller parts based on the query and the preliminary search, include a table and image search query if the user wants it')
|
| 141 |
+
|
| 142 |
+
class Research_plan_node(BaseNode[State]):
    """Turn the query plus preliminary research into a structured plan,
    then move on to the research phase."""

    async def run(self, ctx: GraphRunContext[State]) -> Research_node:
        plan = await research_plan_agent.run(
            f'query:{ctx.state.query}, preliminary_search:{ctx.state.preliminary_research}')
        ctx.state.research_plan = plan.data
        return Research_node()
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# Plain web-search agent used for the first exploratory pass.
search_agent = Agent(llm, tools=[tavily_search_tool(tavily_key)], system_prompt="do a websearch based on the query")


class preliminary_search_node(BaseNode[State]):
    # Entry node of the graph: run one broad search to scope the topic
    # before a detailed plan is drawn up.
    async def run(self, ctx: GraphRunContext[State]) -> Research_plan_node:
        prompt = (' Do a preliminary search to get a global idea of the subject that the user wants to do reseach on as well as the necessary informations to do a search on.\n'
                  f'The subject is based on the query: {ctx.state.query}, return the results of the search.')
        result = await search_agent.run(prompt)
        ctx.state.preliminary_research = result.data
        return Research_plan_node()
|
| 160 |
+
|
| 161 |
+
class Deep_research_engine:
    """Four-stage research pipeline:
    preliminary search -> plan -> research -> paper generation."""

    def __init__(self):
        self.graph = Graph(nodes=[preliminary_search_node, Research_plan_node, Research_node, PaperGen_node])
        self.state = State(query='', preliminary_research='', research_plan=[], research_results=[], validation='', final='')

    async def chat(self, query: str):
        """Chat with the deep research engine.

        Args:
            query (str): The query to search for.

        Returns:
            str: The response from the deep research engine.
        """
        self.state.query = query
        run_result = await self.graph.run(preliminary_search_node(), state=self.state)
        return run_result.output

    def display_graph(self):
        """Display the graph of the deep research engine.

        Returns:
            Image: The image of the graph.
        """
        png_bytes = self.graph.mermaid_image()
        return display(Image(png_bytes))
|
main_agent.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic_ai import Agent, RunContext
|
| 2 |
+
from pydantic_ai.common_tools.tavily import tavily_search_tool
|
| 3 |
+
from pydantic_ai.messages import ModelMessage
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
import os
|
| 6 |
+
from pydantic import Field, BaseModel
|
| 7 |
+
from typing import Dict, List, Any
|
| 8 |
+
from deep_research import Deep_research_engine
|
| 9 |
+
from pydantic_ai.models.gemini import GeminiModel
|
| 10 |
+
from pydantic_ai.providers.google_gla import GoogleGLAProvider
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import Optional
|
| 13 |
+
from spire.doc import Document,FileFormat
|
| 14 |
+
from spire.doc.common import *
|
| 15 |
+
import requests
|
| 16 |
+
from table_maker import table_maker_engine
|
| 17 |
+
from PIL import Image
|
| 18 |
+
from io import BytesIO, StringIO
|
| 19 |
+
import tempfile
|
| 20 |
+
import pandas as pd
|
| 21 |
+
|
| 22 |
+
# Load API keys and configuration from the local .env file.
load_dotenv()
# Tavily web-search API key.
tavily_key=os.getenv('tavily_key')
# Google API key — used for both Gemini and the Custom Search API below.
google_api_key=os.getenv('google_api_key')
# Google Programmable Search Engine ID (used by google_image_search).
pse=os.getenv('pse')

# Shared Gemini model instance used by every agent defined in this module.
llm=GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=google_api_key))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
class Deps:
    # Shared mutable state threaded through every tool call via RunContext.
    # deep_search_results: full payload returned by Deep_research_engine
    #   (read with .get() by the editor tools, so it is used as a dict).
    # quick_search_results: accumulated quick-search answer strings.
    # table_data: current table ({'data': rows, 'columns': names}).
    deep_search_results:dict
    quick_search_results:list[str]
    table_data:dict
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
async def deep_research_agent(ctx:RunContext[Deps], query:str):
    """Run a deep web research pass on a complex query and build a report/paper.

    Args:
        query (str): The research question to investigate.

    Returns:
        str: The stringified research result.
    """
    engine = Deep_research_engine()
    result = await engine.chat(query)
    # Cache the full result and its table so later tools can edit them.
    ctx.deps.deep_search_results = result
    ctx.deps.table_data = result.get('table')
    return str(result)
|
| 52 |
+
|
| 53 |
+
quick_search_agent=Agent(llm,tools=[tavily_search_tool(tavily_key)])


async def quick_research_agent(ctx: RunContext[Deps], query:str):
    """Do a fast web search for a simple query via the Tavily-backed agent.

    Args:
        query (str): The query to search the web for.

    Returns:
        str: The search result text.
    """
    outcome = await quick_search_agent.run(query)
    # Remember the answer so editing tools can cite it later.
    ctx.deps.quick_search_results.append(outcome.data)
    return str(outcome.data)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def google_image_search(query:str):
    """Search for an image using the Google Custom Search API.

    Args:
        query (str): The image search query.

    Returns:
        str | None: URL of the first matching image, or None when the API
        returned no image results (or an error payload without 'items').
    """
    # Google Custom Search REST endpoint.
    url = "https://www.googleapis.com/customsearch/v1"

    params = {
        "q": query,
        "cx": pse,               # Programmable Search Engine ID
        "key": google_api_key,   # Google API key
        "searchType": "image",   # restrict results to images
        "num": 1,                # only the top hit is needed
    }

    # Bug fix: a timeout so a stalled API call cannot hang the agent forever.
    response = requests.get(url, params=params, timeout=30)
    data = response.json()

    if 'items' in data:
        # First (best) image hit.
        return data['items'][0]['link']
    # Bug fix: the original fell off the end implicitly; make the
    # "no result" return explicit.
    return None
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
async def research_editor_tool(ctx: RunContext[Deps], query:str):
    """
    Use this tool to edit the deep search result to make it more accurate following the query's instructions.
    This tool can modify paragraphs, image_url. For image_url, you need to give the query to search for the image.
    Args:
        query (str): The query containing instructions for editing the deep search result
    Returns:
        str: The edited and improved deep search result
    """
    @dataclass
    class edit_route:
        # Bug fix: the original used Field(default_factory=None), which is a
        # no-op (None is that parameter's default) and left the field without
        # a usable default. default=None makes the field genuinely optional.
        paragraph_number:Optional[int] = Field(default=None, description='the number of the paragraph to edit, if the paragraph is not needed to be edited, return None')
        route: str = Field(description='the route to the content to edit, either paragraphs, image_url')

    # Compact summary of the paper given to the routing agent so it can pick
    # what to edit without seeing every paragraph body.
    paper_dict={'title':ctx.deps.deep_search_results.get('title'),
                'image_url':ctx.deps.deep_search_results.get('image_url') if ctx.deps.deep_search_results.get('image_url') else 'None',
                'paragraphs_title':{num:paragraph.get('title') for num,paragraph in enumerate(ctx.deps.deep_search_results.get('paragraphs'))},
                'table':ctx.deps.deep_search_results.get('table') if ctx.deps.deep_search_results.get('table') else 'None',
                'references':ctx.deps.deep_search_results.get('references')}

    route_agent=Agent(llm,result_type=edit_route, system_prompt="you decide the route to the content to edit based on the query's instructions and the paper_dict, either paragraphs, image_url")
    route=await route_agent.run(f'query:{query}, paper_dict:{paper_dict}')
    contents=ctx.deps.deep_search_results

    @dataclass
    class Research_edits:
        edits:str = Field(description='the edits')
    # The editor agent can call google_image_search when asked for a new image.
    editor_agent=Agent(llm,tools=[google_image_search],result_type=Research_edits, system_prompt="you are an editor, you are given a query, some content to edit, and maybe a quick search result (optional), you need to edit the deep search result to make it more accurate following the query's instructions, return only the edited content, no comments")
    if route.data.route=='paragraphs':
        # Edit a single paragraph's body in place.
        content=contents.get('paragraphs')[route.data.paragraph_number]['content']
        res=await editor_agent.run(f'query:{query}, content:{content}, quick_search_results:{ctx.deps.quick_search_results if ctx.deps.quick_search_results else "None"}')
        ctx.deps.deep_search_results['paragraphs'][route.data.paragraph_number]['content']=res.data.edits
    if route.data.route=='image_url':
        # Replace the paper's illustration URL.
        content=contents.get('image_url')
        res=await editor_agent.run(f'query:{query}, content:{content}, quick_search_results:{ctx.deps.quick_search_results if ctx.deps.quick_search_results else "None"}')
        ctx.deps.deep_search_results['image_url']=res.data.edits

    return str(ctx.deps.deep_search_results)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
async def Table_agent(ctx: RunContext[Deps], query:str):
    """
    Use this tool to create a table, edit a table or add a table to the deep search result. the add table to paper route is used to create and add a table to the deep search result.
    Args:
        query (str): The query to create a table, edit a table or add a table to the deep search result
    Returns:
        str: The stringified table (or the whole deep search result for the
        add_table_to_paper route)
    """
    @dataclass
    class route:
        route: str = Field(description='the route to the content to edit, either create_table, edit_table, or add_table_to_paper')
    route_agent=Agent(llm,result_type=route, system_prompt="you decide the route to the content to edit based on the query's instructions, return only the route, either create_table, edit_table, or add_table_to_paper")
    route=await route_agent.run(f'query:{query}')

    # Shared structured-output schema for the table-generating agents.
    # (Previously declared twice, once in each branch below.)
    class Table_row(BaseModel):
        data: List[str] = Field(description='the data of the row')
    class Table(BaseModel):
        rows: List[Table_row] = Field(description='the rows of the table')
        columns: List[str] = Field(description='the columns of the table')

    if route.data.route=='create_table':
        # Run the standalone table-making pipeline and cache its output.
        table_maker=table_maker_engine()
        table=await table_maker.chat(query)
        ctx.deps.table_data=table
        return str(table)

    if route.data.route=='edit_table':
        table=ctx.deps.table_data
        table_editor=Agent(llm, result_type=Table, system_prompt="edit the table based on the query's instructions, the research results (if any) and the quick search results(if any)")
        generated_table=await table_editor.run(f'query:{query}, table:{table}, research:{ctx.deps.deep_search_results if ctx.deps.deep_search_results else "None"}, quick_search_results:{ctx.deps.quick_search_results if ctx.deps.quick_search_results else "None"}')
        ctx.deps.table_data={'data':[row.data for row in generated_table.data.rows], 'columns':generated_table.data.columns}
        return str(ctx.deps.table_data)

    if route.data.route=='add_table_to_paper':
        table_creator=Agent(llm, result_type=Table, system_prompt="create a table based on the query's instructions, the research results (if any) and the quick search results(if any)")
        generated_table=await table_creator.run(f'query:{query}, research:{ctx.deps.deep_search_results if ctx.deps.deep_search_results else "None"}, quick_search_results:{ctx.deps.quick_search_results if ctx.deps.quick_search_results else "None"}')
        # Attach the generated table to the paper and mirror it in table_data.
        ctx.deps.deep_search_results['table']={'data':[row.data for row in generated_table.data.rows], 'columns':generated_table.data.columns}
        ctx.deps.table_data=ctx.deps.deep_search_results['table']
        return str(ctx.deps.deep_search_results)
|
| 184 |
+
|
| 185 |
+
@dataclass
class Message_state:
    # Rolling conversation history replayed to the agent on every turn.
    messages: list[ModelMessage]
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class Main_agent:
    """Top-level chat orchestrator wiring the research tools to one agent."""

    def __init__(self):
        self.agent=Agent(llm, system_prompt="you are a research assistant, you are given a query, leverage what tool(s) to use, make suggestions to the user about the tools to use, \
                    never show the output of the tools, except for the table, notify the user about what next step they can take, inform the user about the table,\
                    and the table's editable nature either in the chat or in the files section",
                    tools=[deep_research_agent,research_editor_tool,quick_research_agent,Table_agent])
        # deep_search_results is consumed as a dict downstream (.get calls in
        # the editor tools), so initialise it as {} rather than the original [].
        self.deps=Deps(deep_search_results={}, quick_search_results=[], table_data={})
        self.memory=Message_state(messages=[])

    async def chat(self, query:str):
        """Send one user message and return the agent's reply text."""
        result = await self.agent.run(query,deps=self.deps, message_history=self.memory.messages)
        self.memory.messages=result.all_messages()
        return result.data

    def reset(self):
        """Clear conversation memory and all accumulated tool state."""
        self.memory.messages=[]
        # Bug fix: the original omitted the required table_data argument,
        # which raised TypeError every time reset() was called.
        self.deps=Deps(deep_search_results={}, quick_search_results=[], table_data={})
|
| 208 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ipython==9.0.2
|
| 2 |
+
Pillow==11.1.0
|
| 3 |
+
pydantic==2.11.2
|
| 4 |
+
pydantic_ai==0.0.52
|
| 5 |
+
pydantic_graph==0.0.52
|
| 6 |
+
python-dotenv==1.1.0
|
| 7 |
+
Requests==2.32.3
|
| 8 |
+
Spire.Doc.Free==12.12.0
|
| 9 |
+
tavily_python==0.5.1
|
| 10 |
+
tabulate==0.9.0
|
| 11 |
+
pandas==2.1.1
|
| 12 |
+
python-pptx
|
| 13 |
+
streamlit==1.26.0
|
| 14 |
+
nest_asyncio==1.6.0
|
table_maker.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic_graph import BaseNode, End, GraphRunContext, Graph
|
| 2 |
+
from pydantic_ai import Agent
|
| 3 |
+
from pydantic_ai.common_tools.tavily import tavily_search_tool
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from pydantic import Field, BaseModel
|
| 6 |
+
from typing import List, Dict, Optional, Any
|
| 7 |
+
from pydantic_ai.models.gemini import GeminiModel
|
| 8 |
+
from pydantic_ai.providers.google_gla import GoogleGLAProvider
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
import os
|
| 11 |
+
from tavily import TavilyClient
|
| 12 |
+
from IPython.display import Image, display
|
| 13 |
+
import requests
|
| 14 |
+
|
| 15 |
+
# Load API keys from the local .env file.
load_dotenv()
# Google API key for the Gemini model.
google_api_key=os.getenv('google_api_key')
# Tavily web-search API key.
tavily_key=os.getenv('tavily_key')
# Raw Tavily client used by data_research_node (bypasses the agent layer).
tavily_client = TavilyClient(api_key=tavily_key)
# Shared Gemini model instance used by every node in this graph.
llm=GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=google_api_key))
|
| 20 |
+
|
| 21 |
+
@dataclass
class State:
    # Mutable state threaded through every node of the table-making graph.
    query:str  # the user's original request
    research:List[str]  # collected search-result snippets
    table:dict  # final output: {'data': rows, 'columns': names}
    preliminary_research:str  # text produced by the preliminary search node
    # NOTE(review): annotated List[str], but Research_plan_node stores
    # search_query model objects here — confirm intended element type.
    research_plan:List[str]
|
| 28 |
+
|
| 29 |
+
#define the table row and table schema
class Table_row(BaseModel):
    # One row of the generated table.
    data: List[str] = Field(description='the data of the row')
class Table(BaseModel):
    # Structured-output schema the table-making agent must return.
    rows: List[Table_row] = Field(description='the rows of the table')
    columns: List[str] = Field(description='the columns of the table')
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class table_maker_node(BaseNode[State]):
    """Terminal graph node: condense the gathered research into a table."""

    async def run(self, ctx: GraphRunContext[State])->End:
        builder=Agent(llm, result_type=Table, system_prompt="generate a detailed table in a dictionary format based on the research and the query")
        outcome=await builder.run(f'query:{ctx.state.query}, research:{ctx.state.research}')
        # Flatten the structured Table model into a plain dict payload.
        ctx.state.table={'data':[entry.data for entry in outcome.data.rows], 'columns':outcome.data.columns}
        return End(ctx.state.table)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class data_research_node(BaseNode[State]):
    """Execute every planned search query and collect high-confidence text."""

    async def run(self, ctx: GraphRunContext[State])->table_maker_node:
        # Bug fix: the original bound both loops to the same name `i`,
        # shadowing the outer variable inside the inner loop.
        for planned in ctx.state.research_plan:
            response = tavily_client.search(planned.search_query)

            for hit in response.get('results'):
                # Keep only results Tavily scores above 0.50 relevance.
                # (`or 0` guards against a missing/None score, which would
                # have raised TypeError in the original comparison.)
                if (hit.get('score') or 0) > 0.50:
                    ctx.state.research.append(hit.get('content'))
        return table_maker_node()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class search_query(BaseModel):
    # A single planned web search.
    search_query: str = Field(description='the detailed web search query for the research')

class Research_plan(BaseModel):
    # Structured output: the full set of searches to run for this research task.
    search_queries: List[search_query] = Field(description='the detailed web search queries for the research')

# Module-level agent that turns (query, preliminary search) into a Research_plan.
research_plan_agent=Agent(llm, result_type=Research_plan, system_prompt='generate a detailed research plan breaking down the research into smaller parts based on the query and the preliminary search')
|
| 65 |
+
|
| 66 |
+
class Research_plan_node(BaseNode[State]):
    """Break the research task into a list of targeted web-search queries."""

    async def run(self, ctx: GraphRunContext[State])->data_research_node:
        plan=await research_plan_agent.run(
            f'query:{ctx.state.query}, preliminary_search:{ctx.state.preliminary_research}'
        )
        ctx.state.research_plan=plan.data.search_queries
        return data_research_node()
|
| 73 |
+
|
| 74 |
+
class preliminary_search_node(BaseNode[State]):
    """Entry node: run one broad web search to scope the research subject."""

    async def run(self, ctx: GraphRunContext[State]) -> Research_plan_node:
        scout=Agent(llm, tools=[tavily_search_tool(tavily_key)], system_prompt="do a websearch based on the query")
        # Prompt text reproduced verbatim (it is part of runtime behavior).
        prompt = (' Do a preliminary search to get a global idea of the subject that the user wants to do reseach on as well as the necessary informations to do a search on.\n'
                  f'The subject is based on the query: {ctx.state.query}, return the results of the search.')
        outcome=await scout.run(prompt)
        ctx.state.preliminary_research=outcome.data
        return Research_plan_node()
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class table_maker_engine:
    """Wrapper around the four-node table-building graph."""

    def __init__(self):
        self.graph=Graph(nodes=[preliminary_search_node, Research_plan_node, data_research_node, table_maker_node])
        self.state=State(query='', research=[], table={}, preliminary_research='', research_plan=[])

    async def chat(self,query:str):
        """Run the full table-making pipeline for a query.

        Args:
            query (str): The subject to build a table about.

        Returns:
            dict: The generated table ({'data': rows, 'columns': names}).
        """
        self.state.query=query
        outcome=await self.graph.run(preliminary_search_node(),state=self.state)
        return outcome.output

    def display_graph(self):
        """Render the graph topology as a Mermaid diagram image.

        Returns:
            The IPython display handle for the rendered image.
        """
        png=self.graph.mermaid_image()
        return display(Image(png))
|