Soham Waghmare committed
Commit 88139f0 · 1 Parent(s): 501bdbe

fix: types, off-by-one graph scraping logic, pull depth, breadth, num_sites

backend/app.py CHANGED
@@ -30,7 +30,7 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-sio = socketio.AsyncServer(cors_allowed_origins=CORS_ALLOWED_ORIGINS, ping_timeout=60, ping_interval=10, async_mode="asgi")
+sio = socketio.AsyncServer(cors_allowed_origins=CORS_ALLOWED_ORIGINS, ping_timeout=120, ping_interval=10, async_mode="asgi")
 app.mount("/", socketio.ASGIApp(sio))
 
 
@@ -76,11 +76,15 @@ async def health_check(sid, data):
 
 @sio.event
 async def start_research(sid, data):
-    knet, scraper = await session_manager.get_or_create_session(sid)
-
     try:
         data = json.loads(data) if type(data) != dict else data
         topic = data.get("topic")
+        max_depth: int = data.get("max_depth")
+        max_breadth: int = data.get("max_breadth")
+        num_sites_per_query: int = data.get("num_sites_per_query")
+
+        knet, _ = await session_manager.get_or_create_session(sid)
+
         session_id = sid
         logger.info(f"Starting research for client {session_id} on topic: {topic}")
@@ -96,7 +100,7 @@ async def start_research(sid, data):
             logger.error(f"Error in progress callback: {str(e)}")
             raise e
 
-        research_results = await knet.conduct_research(topic, progress_callback)
+        research_results = await knet.conduct_research(topic, progress_callback, max_depth, max_breadth, num_sites_per_query)
         logger.info(f"Research completed for topic: {topic}")
         await sio.emit("research_complete", research_results, room=session_id)
@@ -107,7 +111,7 @@ async def start_research(sid, data):
 
 @sio.event
 async def test(sid, data):
-    knet, scraper = await session_manager.get_or_create_session(sid)
+    knet, _ = await session_manager.get_or_create_session(sid)
     print("Testing...")
     data = json.loads(data) if type(data) != dict else data
     res = await knet.scraper._scrape_page(data["url"])
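Note on the new payload: start_research now reads max_depth, max_breadth, and num_sites_per_query from the event data and forwards them to conduct_research. A minimal client sketch for exercising the updated event, assuming the python-socketio client package and a backend served at http://localhost:8000 (the URL, topic, and parameter values are illustrative, not part of this commit):

import asyncio
import json

import socketio


async def main():
    sio = socketio.AsyncClient()

    @sio.on("research_complete")
    async def on_complete(data):
        # Print a truncated preview of the final report, then disconnect.
        print("Report received:", json.dumps(data)[:200])
        await sio.disconnect()

    # Server URL is an assumption for illustration.
    await sio.connect("http://localhost:8000")
    # start_research now expects depth/breadth/site-count alongside the topic.
    await sio.emit("start_research", {
        "topic": "solid-state batteries",
        "max_depth": 1,
        "max_breadth": 2,
        "num_sites_per_query": 5,
    })
    await sio.wait()


if __name__ == "__main__":
    asyncio.run(main())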
backend/knet.py CHANGED
@@ -18,7 +18,7 @@ load_dotenv()
 
 
 class ResearchProgress:
-    def __init__(self, callback=None):
+    def __init__(self, callback):
         self.progress = 0
         self.callback = callback
 
@@ -31,7 +31,7 @@ class ResearchProgress:
 
 
 class KNet:
-    def __init__(self, scraper_instance=None):
+    def __init__(self, scraper_instance, max_depth: int = 1, max_breadth: int = 1, num_sites_per_query: int = 5):
         self.api_key = os.getenv("GOOGLE_API_KEY")
         assert self.api_key, "Google API key is required"
 
@@ -80,9 +80,9 @@ class KNet:
         self.scraper = scraper_instance
 
         self.logger = logging.getLogger(__name__)
-        self.max_depth = 2
-        self.max_breadth = 3
-        self.num_sites_per_query = 5
+        self.max_depth = max_depth
+        self.max_breadth = max_breadth
+        self.num_sites_per_query = num_sites_per_query
 
         self.search_prompt = """Generate 3-5 specific search queries to research the following topic: {topic}
 
@@ -147,13 +147,14 @@ class KNet:
     def _track_tokens(self, tokens: int) -> None:
         self.token_count += tokens
 
-    def _should_branch_deeper(self, node: ResearchNode, topic: str, retry_count=0) -> bool:
+    def _should_branch_deeper(self, node: ResearchNode, topic: str, retry_count: int = 0) -> bool:
         try:
+            if node.depth > self.max_depth:
+                return False
+
             # Generate summary of key findings into research_manager's context
             if node.data:
-                findings = ("\n" + "-" * 10 + "Next data" + "-" * 10 + "\n").join(
-                    [json.dumps(d, indent=2) for d in node.data]
-                )
+                findings = ("\n" + "-" * 10 + "Next data" + "-" * 10 + "\n").join([json.dumps(d, indent=2) for d in node.data])
                 response = self.llm.generate_content(
                     f"Extract key findings from the following data related to the topic '{topic}':\n{findings}"
                 )
@@ -181,7 +182,13 @@ class KNet:
             self.logger.error(f"Branch decision failed: {str(e)}")
             raise e
 
-    async def conduct_research(self, topic: str, progress_callback=None) -> Dict[str, Any]:
+    async def conduct_research(
+        self, topic: str, progress_callback, max_depth: int, max_breadth: int, num_sites_per_query: int
+    ) -> Dict[str, Any]:
+        self.max_depth = max_depth
+        self.max_breadth = max_breadth
+        self.num_sites_per_query = num_sites_per_query
+
         self.ctx_researcher = []
         self.ctx_manager = []
         self.token_count = 0
@@ -198,7 +205,7 @@ class KNet:
         while to_explore:
             current_node, current_depth = to_explore.popleft()
 
-            if current_node.query in explored_queries or current_depth >= self.max_depth:
+            if current_node.query in explored_queries or current_depth > self.max_depth:
                 continue
 
             self.logger.info(f"Exploring: {current_node.query} (Depth: {current_depth})")
@@ -223,12 +230,10 @@ class KNet:
             await progress.update(30, "Generating comprehensive report...")
             final_report = self._generate_final_report(root_node)
 
-            self.logger.info(
-                f"Research completed. Explored {len(explored_queries)} queries across {root_node.max_depth()} levels"
-            )
+            self.logger.info(f"Research completed. Explored {len(explored_queries)} queries across {root_node.max_depth()} levels")
             await progress.update(100, "Research complete!")
 
-            with open("output.json", "a") as f:
+            with open("output.json", "a", encoding="utf-8") as f:
                 json.dump(final_report, f, indent=2)
             return final_report
 
@@ -236,9 +241,9 @@ class KNet:
             self.logger.error(f"Research failed: {str(e)}")
             raise e
 
-    def _analyze_and_branch(self, node: ResearchNode, topic: str, retry_count=0) -> List[ResearchNode]:
+    def _analyze_and_branch(self, node: ResearchNode, topic: str, retry_count: int = 0) -> List[ResearchNode]:
         try:
-            if not node.data:
+            if not node.data or node.depth > self.max_depth:
                 return []
 
             analysis_prompt = dedent(
@@ -255,9 +260,7 @@ class KNet:
             - query (string)"""
             )
 
-            response = self.research_manager.generate_content(
-                analysis_prompt, generation_config={**self.analysis_schema}
-            )
+            response = self.research_manager.generate_content(analysis_prompt, generation_config={**self.analysis_schema})
             self._track_tokens(response.usage_metadata.total_token_count)
             result = json.loads(response.text)
             self.logger.info(f"New branches for '{node.query}': {result['branches']}")
@@ -279,7 +282,7 @@ class KNet:
             self.logger.error(f"Branch analysis failed: {str(e)}")
             raise e
 
-    def _generate_final_report(self, root_node: ResearchNode, retry_count=0) -> Dict[str, Any]:
+    def _generate_final_report(self, root_node: ResearchNode, retry_count: int = 0) -> Dict[str, Any]:
         try:
             findings = "\n".join(self.ctx_manager)
             with open("output.json", "w") as f:
@@ -310,9 +313,12 @@ class KNet:
         def build_tree_structure(node: ResearchNode) -> Dict:
             if not node:
                 return {}
+
+            sources = [d["url"] for d in node.data if d.get("url")]
             return {
                 "query": node.query,
                 "depth": node.depth,
+                "sources": sources,
                 "children": [build_tree_structure(child) for child in node.children],
             }
 
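Note on the off-by-one fix named in the commit message: the BFS check changes from current_depth >= self.max_depth to current_depth > self.max_depth, so nodes at exactly max_depth are now explored; previously, max_depth = 1 meant only the depth-0 root ran. A standalone sketch of the boundary, assuming a bare (query, depth) queue in place of the real ResearchNode tree:

from collections import deque

# Hypothetical stand-ins for the real research queue.
max_depth = 1
to_explore = deque([("root query", 0), ("follow-up A", 1), ("follow-up B", 1)])

explored_old, explored_new = [], []
for query, depth in to_explore:
    if not (depth >= max_depth):   # old check: drops every depth-1 node
        explored_old.append(query)
    if not (depth > max_depth):    # fixed check: depth-1 nodes are explored
        explored_new.append(query)

print(explored_old)  # ['root query']
print(explored_new)  # ['root query', 'follow-up A', 'follow-up B']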
backend/research_node.py CHANGED
@@ -1,5 +1,6 @@
 from datetime import datetime
 from typing import Any, Dict, List, Optional
+import copy
 
 
 class ResearchNode:
@@ -34,7 +35,7 @@ class ResearchNode:
         return len(self.children) + sum([child.total_children() for child in self.children])
 
     def get_all_data(self) -> List[Dict[str, Any]]:
-        data = self.data
+        data = copy.deepcopy(self.data)
         for child in self.children:
             data.extend(child.get_all_data())
         return data
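Note on the deepcopy: data = self.data aliased the node's own list, so the extend call below it appended every descendant's records onto the node itself, and each later get_all_data call compounded the duplication. A minimal repro with placeholder records:

import copy

# The dicts stand in for scraped records on a parent and child node.
parent_data = [{"url": "https://example.com/a"}]
child_data = [{"url": "https://example.com/b"}]

# Old behavior: 'data' is the same list object as parent_data,
# so extending it silently grows the node's stored data.
data = parent_data
data.extend(child_data)
print(len(parent_data))  # 2 -- the node was mutated

# Fixed behavior: deep-copy first, leaving the node untouched
# (and shielding the nested dicts from later mutation too).
parent_data = [{"url": "https://example.com/a"}]
data = copy.deepcopy(parent_data)
data.extend(child_data)
print(len(parent_data))  # 1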
backend/scraper.py CHANGED
@@ -185,7 +185,6 @@ class CrawlForAIScraper:
 
         # Perform a search to get a list of webpages
         search_results = await self._search(query, num_sites)
-        self.logger.info(f"Found {len(search_results)} search results")
 
         # Scrape each webpage
         scraped_data = []
@@ -219,11 +218,9 @@ class CrawlForAIScraper:
                 if not url.startswith(("http://", "https://")):
                     url = "https://" + url
                 search_results.append(url)
-                if len(search_results) >= num_results:
-                    break
 
-            self.logger.info(f"Found {len(search_results)} URLs")
-            return search_results
+            self.logger.info(f"Found {len(search_results)} results.")
+            return search_results[:num_results]
 
         except requests.exceptions.RequestException as e:
             self.logger.error(f"Google search error: {str(e)}")