Soham Waghmare committed on
Commit
73fba58
·
1 Parent(s): ac03e8a

feat: migrate from BFS to DFS for working with research_plan

Browse files
Files changed (4) hide show
  1. .gitignore +1 -1
  2. backend/app.py +3 -8
  3. backend/knet.py +176 -90
  4. backend/research_node.py +6 -6
.gitignore CHANGED
@@ -9,7 +9,7 @@ backend/*.pyo
9
  backend/.venv/
10
  backend/.env*
11
  backend/downloads/*
12
- backend/output.json
13
  backend/.ruff_cache/
14
 
15
  # Next.js ignore files
 
9
  backend/.venv/
10
  backend/.env*
11
  backend/downloads/*
12
+ backend/*.log.*
13
  backend/.ruff_cache/
14
 
15
  # Next.js ignore files
backend/app.py CHANGED
@@ -86,7 +86,6 @@ async def start_research(sid, data):
86
  data = json.loads(data) if type(data) is not dict else data
87
  topic = data.get("topic").strip()
88
  max_depth: int = data.get("max_depth")
89
- max_breadth: int = data.get("max_breadth")
90
  num_sites_per_query: int = data.get("num_sites_per_query")
91
 
92
  knet, _ = await session_manager.get_or_create_session(sid)
@@ -94,14 +93,10 @@ async def start_research(sid, data):
94
  session_id = sid
95
  logger.info(f"Starting research for client {session_id}.\nTopic '{topic}'")
96
 
97
- async def progress_callback(status):
98
- await sio.emit(
99
- "status",
100
- {"message": status["message"], "progress": status["progress"]},
101
- room=session_id,
102
- )
103
 
104
- research_results = await knet.conduct_research(topic, progress_callback, max_depth, max_breadth, num_sites_per_query)
105
  logger.info(f"Research completed for topic: {topic}")
106
  await sio.emit("research_complete", research_results, room=session_id)
107
 
 
86
  data = json.loads(data) if type(data) is not dict else data
87
  topic = data.get("topic").strip()
88
  max_depth: int = data.get("max_depth")
 
89
  num_sites_per_query: int = data.get("num_sites_per_query")
90
 
91
  knet, _ = await session_manager.get_or_create_session(sid)
 
93
  session_id = sid
94
  logger.info(f"Starting research for client {session_id}.\nTopic '{topic}'")
95
 
96
+ async def progress_callback(status: dict):
97
+ await sio.emit("status", status, room=session_id)
 
 
 
 
98
 
99
+ research_results = await knet.conduct_research(topic, progress_callback, max_depth, num_sites_per_query)
100
  logger.info(f"Research completed for topic: {topic}")
101
  await sio.emit("research_complete", research_results, room=session_id)
102
 
backend/knet.py CHANGED
@@ -19,27 +19,50 @@ load_dotenv()
19
 
20
  class Prompt:
21
  def __init__(self) -> None:
22
- self.continue_branch = dedent("""Given the current research context and findings, should we explore this branch deeper?
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  Current Topic: {query}
25
- Current Depth: {depth}
26
- Path from Root: {path}
27
- Key Findings:
28
- {findings}
 
29
 
30
  Consider:
31
- 1. Relevance to main topic
32
- 2. Potential for new insights
33
- 3. Depth vs breadth tradeoff
34
- 4. Information saturation
35
 
36
- Return only: decision: true/false""")
37
 
38
- self.search_query = dedent("""Based on the following findings about "{topic}", suggest new research directions.
39
- Findings:
 
 
 
 
 
 
40
  {ctx_manager}
41
 
42
- Suggest up to {n} specific google search queries that would help data which:
 
43
  - Builds upon these findings
44
  - Explores different aspects
45
  - Goes deeper into important details
@@ -47,34 +70,67 @@ class Prompt:
47
  Return as JSON array of objects with properties:
48
  - query (string)""")
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  class Schema:
52
  def __init__(self) -> None:
 
 
 
 
 
 
53
  self.continue_branch = genai.types.Schema(
54
  type=genai.types.Type.OBJECT,
55
  required=["decision"],
56
- properties={
57
- "decision": genai.types.Schema(type=genai.types.Type.BOOLEAN),
58
- },
59
  )
60
 
61
  self.search_query = genai.types.Schema(
62
  type=genai.types.Type.OBJECT,
63
  required=["branches"],
 
 
 
 
 
 
64
  properties={
65
- "branches": genai.types.Schema(
66
- type=genai.types.Type.ARRAY,
67
- items=genai.types.Schema(
68
- type=genai.types.Type.OBJECT,
69
- required=["query"],
70
- properties={
71
- "query": genai.types.Schema(type=genai.types.Type.STRING),
72
- },
73
- ),
74
- )
75
  },
76
  )
77
 
 
 
 
 
 
 
78
 
79
  class ResearchProgress:
80
  def __init__(self, callback):
@@ -82,90 +138,104 @@ class ResearchProgress:
82
  self.callback = callback
83
 
84
  async def update(self, progress: int, message: str):
85
- self.progress += progress
86
- if self.progress > 100:
87
- self.progress = 100
88
- if self.callback:
89
- await self.callback({"progress": self.progress, "message": message})
 
90
 
91
 
92
  class KNet:
93
- def __init__(self, scraper_instance: CrawlForAIScraper, max_depth: int = 1, max_breadth: int = 1, num_sites_per_query: int = 5):
94
  self.api_key = os.getenv("GOOGLE_API_KEY")
95
  assert self.api_key, "Google API key is required"
96
  self.scraper = scraper_instance
97
  self.logger = logging.getLogger(__name__)
98
  self.prompt = Prompt()
99
  self.schema = Schema()
 
100
 
101
  # Init Google GenAI client
102
  self.genai_client = genai.Client(api_key=self.api_key)
103
 
104
  # Parameters
105
  self.max_depth = max_depth
106
- self.max_breadth = max_breadth
107
  self.num_sites_per_query = num_sites_per_query
108
 
109
  # Global State
 
 
110
  self.ctx_researcher: list[str] = []
111
  self.ctx_manager: list[str] = []
112
  self.token_count: int = 0
113
 
114
- async def conduct_research(self, topic: str, progress_callback, max_depth: int, max_breadth: int, num_sites_per_query: int) -> dict:
115
  # Local Runtime State
116
- progress = ResearchProgress(progress_callback)
117
  self.max_depth = max_depth
118
- self.max_breadth = max_breadth
119
  self.num_sites_per_query = num_sites_per_query
120
 
121
  # Reset global state
 
 
122
  self.ctx_researcher = []
123
  self.ctx_manager = []
124
  self.token_count = 0
125
 
126
  try:
 
 
 
 
 
127
  # Generate initial search query
128
  query = self.generate_content(
129
- self.prompt.search_query.format(topic=topic, ctx_manager=json.dumps(self.ctx_manager, indent=2), n=1),
130
- schema=self.schema.search_query,
131
- )
132
- root_node = ResearchNode(query.get("branches")[0]["query"])
 
133
  to_explore = deque([(root_node, 0)]) # (node, depth) pairs
134
  explored_queries = set() # {string, string, ...}
135
 
136
- await progress.update(5, "Starting research...")
137
 
138
- while to_explore:
 
139
  current_node, current_depth = to_explore.popleft()
140
-
141
- if current_node.query in explored_queries or current_depth > self.max_depth:
142
- continue
143
-
144
- self.logger.info(f"Exploring: {current_node.query} (Depth: {current_depth})")
145
- await progress.update(5, f"Exploring: {current_node.query}")
146
-
147
- # Search and scrape
148
- current_node.data = await self.scraper.search_and_scrape(
149
- current_node.query, self.num_sites_per_query
150
- ) # node -> data = [{url:...}, {url:...}, ...]
151
- self.ctx_researcher.append(json.dumps(current_node.data, indent=2))
152
- explored_queries.add(current_node.query)
153
-
154
- # Only branch if we have data and haven't reached max depth
155
- if self._should_continue_branch(current_node, topic):
156
- if current_node.data and current_depth < self.max_depth:
157
- new_branches = self._gen_queries(current_node, topic)
158
- for branch in new_branches:
159
- to_explore.append((branch, current_depth + 1))
 
 
 
160
 
161
  # Generate final report
162
- await progress.update(30, "Generating comprehensive report...")
163
- final_report = self._generate_final_report(root_node)
164
 
165
  self.logger.info(f"Research completed. Explored {len(explored_queries)} queries across {root_node.max_depth()} levels")
166
- await progress.update(100, "Research complete!")
167
 
168
- with open("output.json", "a", encoding="utf-8") as f:
169
  json.dump(final_report, f, indent=2)
170
  return final_report
171
 
@@ -173,15 +243,30 @@ class KNet:
173
  self.logger.error("Research failed", exc_info=True)
174
  raise
175
 
176
- def _generate_final_report(self, root_node: ResearchNode, retry_count: int = 1) -> Dict[str, Any]:
177
  try:
178
- findings = "\n".join(self.ctx_manager)
179
- with open("output.json", "w", encoding="utf-8") as f:
 
180
  f.write(findings)
181
- prompt = f"""Generate a comprehensive report on the topic "{root_node.query}" based on the following research findings:
182
- {findings}
183
- """
184
- response = self.generate_content(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
  # Collate multimedia content
187
  media_content = {"images": [], "videos": [], "links": [], "references": []}
@@ -203,7 +288,6 @@ class KNet:
203
  def build_tree_structure(node: ResearchNode) -> Dict:
204
  if not node:
205
  return {}
206
-
207
  sources = [d["url"] for d in node.data if d.get("url")]
208
  return {
209
  "query": node.query,
@@ -215,7 +299,7 @@ class KNet:
215
  return {
216
  "topic": root_node.query,
217
  "timestamp": datetime.now().isoformat(),
218
- "content": response,
219
  "media": media_content,
220
  "research_tree": build_tree_structure(root_node),
221
  "metadata": {
@@ -229,7 +313,7 @@ class KNet:
229
  except Exception as e:
230
  if e in ["GEMINI_RECITATION", "NO_RESPONSE"] and retry_count < 3:
231
  self.logger.error(f"Retrying final report:C:{retry_count / 3}", exc_info=True)
232
- self._generate_final_report(root_node, retry_count + 1)
233
  self.logger.error("Error generating final report", exc_info=True)
234
  raise
235
 
@@ -239,11 +323,13 @@ class KNet:
239
  return []
240
 
241
  prompt = self.prompt.search_query.format(
242
- topic=topic,
243
- ctx_manager=json.dumps(self.ctx_manager, indent=2),
244
- n=self.max_breadth,
 
 
245
  )
246
- response = self.generate_content(prompt, schema=self.schema.search_query)
247
  self.logger.info(f"Spawn branches '{node.query}':\n{json.dumps(response['branches'], indent=2)}")
248
 
249
  # Add children to current node
@@ -252,7 +338,7 @@ class KNet:
252
  # |-> child
253
  new_nodes = []
254
  for branch in response.get("branches", []):
255
- child_node = node.add_child(branch["query"])
256
  new_nodes.append(child_node)
257
 
258
  self.logger.info(f"Spawned {len(new_nodes)} new branch(es)")
@@ -261,7 +347,7 @@ class KNet:
261
  except Exception as e:
262
  if e in ["GEMINI_RECITATION", "NO_RESPONSE"] and retry_count < 3:
263
  self.logger.error(f"Retrying _gen_queries | C:{retry_count / 3}", exc_info=True)
264
- self._gen_queries(node, topic, retry_count + 1)
265
  self.logger.error("_gen_queries failed", exc_info=True)
266
  raise
267
 
@@ -273,15 +359,15 @@ class KNet:
273
  # Generate summary of key findings into the manager's context
274
  if node.data:
275
  findings = ("\n" + "-" * 10 + "Next data" + "-" * 10 + "\n").join([json.dumps(d, indent=2) for d in node.data])
276
- response = self.generate_content(f"Extract key findings from the following data related to the topic '{topic}':\n{findings}")
277
  self.ctx_manager.append(response)
278
 
279
  # Research manager takes decision to proceed or not
280
  prompt = self.prompt.continue_branch.format(
 
281
  query=node.query,
282
- depth=node.depth,
283
- path=" -> ".join(node.get_path_to_root()),
284
- findings="\n".join(self.ctx_manager),
285
  )
286
  response = self.generate_content(prompt, schema=self.schema.continue_branch)
287
  self.logger.info(f"Branch decision '{node.query}': {response['decision']}")
@@ -291,11 +377,11 @@ class KNet:
291
  except Exception as e:
292
  if e in ["GEMINI_RECITATION", "NO_RESPONSE"] and retry_count < 3:
293
  self.logger.error(f"Retrying branch decision:C:{retry_count / 3}", exc_info=True)
294
- self._should_continue_branch(node, topic, retry_count + 1)
295
  self.logger.error("Branch decision failed:", exc_info=True)
296
  raise
297
 
298
- def generate_content(self, prompt: str, schema: Dict[str, Any] = {}) -> Dict[str, Any] | str:
299
  safe = [
300
  types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=types.HarmBlockThreshold.BLOCK_NONE),
301
  types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=types.HarmBlockThreshold.BLOCK_NONE),
@@ -305,10 +391,10 @@ class KNet:
305
  ]
306
  if schema:
307
  generate_content_config = types.GenerateContentConfig(
308
- temperature=0.9, response_mime_type="application/json", safety_settings=safe, response_schema=schema
309
  )
310
  else:
311
- generate_content_config = types.GenerateContentConfig(temperature=0.9, response_mime_type="text/plain", safety_settings=safe)
312
 
313
  try:
314
  response = self.genai_client.models.generate_content(model="gemini-2.0-flash", contents=prompt, config=generate_content_config)
 
19
 
20
  class Prompt:
21
  def __init__(self) -> None:
22
+ self.research_plan = dedent("""You are an expert AI Deep Research agent, part of a Multiagent system.
23
+
24
+ User query:
25
+ "{topic}".
26
+
27
+ ---
28
+ Generate few very high level steps on which other agents can do info collection runs. Provide only data collection steps, no data identification, summarization, manipulation, selection, etc.
29
+ Return a string array of steps.""")
30
+
31
+ self.site_summary = dedent("""Extract specific verbatim key information from the following content that is related to the topic "{query}". No small talk.
32
+ Findings:
33
+ {findings}""")
34
+
35
+ self.continue_branch = dedent("""Given the current state of research, decide whether to continue exploring the current branch or not.
36
+ Global Research Plan:
37
+ {research_plan}
38
 
39
  Current Topic: {query}
40
+ Searched Queries:
41
+ {past_queries}
42
+
43
+ Findings under current topic:
44
+ {ctx_manager}
45
 
46
  Consider:
47
+ - Information saturation
48
+ - Information duplication
49
+ - Coverage of current topic
50
+ - Potential for new insights
51
 
52
+ Return only decision: true/false""")
53
 
54
+ self.search_query = dedent("""Based on the following findings on topic {vertical}, suggest new research directions.
55
+ Global Research Plan:
56
+ {research_plan}
57
+
58
+ Searched queries:
59
+ {past_queries}
60
+
61
+ Findings under current topic:
62
  {ctx_manager}
63
 
64
+ Suggest up to {n} specific google search queries that:
65
+ - Covers what has not been covered yet
66
  - Builds upon these findings
67
  - Explores different aspects
68
  - Goes deeper into important details
 
70
  Return as JSON array of objects with properties:
71
  - query (string)""")
72
 
73
+ self.report_outline = dedent("""Generate a comprehensive report outline on the user query based on the following research findings:
74
+ User query:
75
+ {topic}
76
+
77
+ Findings:
78
+ {ctx_manager}
79
+
80
+ The outline should include:
81
+ - Title
82
+ - List of h2 headings""")
83
+
84
+ self.report_fillin = dedent("""Fill in the content for the following report outline on the user query based on the following research findings:
85
+ User query:
86
+ {topic}
87
+
88
+ Findings:
89
+ {ctx_manager}
90
+
91
+ Report Outline:
92
+ {report_outline}
93
+
94
+ Current slot to fill in: (h2 heading)
95
+ {slot}
96
+ """)
97
+
98
 
99
  class Schema:
100
  def __init__(self) -> None:
101
+ self.research_plan = genai.types.Schema(
102
+ type=genai.types.Type.OBJECT,
103
+ required=["steps"],
104
+ properties={"steps": genai.types.Schema(type=genai.types.Type.ARRAY, items=genai.types.Schema(type=genai.types.Type.STRING))},
105
+ )
106
+
107
  self.continue_branch = genai.types.Schema(
108
  type=genai.types.Type.OBJECT,
109
  required=["decision"],
110
+ properties={"decision": genai.types.Schema(type=genai.types.Type.BOOLEAN)},
 
 
111
  )
112
 
113
  self.search_query = genai.types.Schema(
114
  type=genai.types.Type.OBJECT,
115
  required=["branches"],
116
+ properties={"branches": genai.types.Schema(type=genai.types.Type.ARRAY, items=genai.types.Schema(type=genai.types.Type.STRING))},
117
+ )
118
+
119
+ self.report_outline = genai.types.Schema(
120
+ type=genai.types.Type.OBJECT,
121
+ required=["title", "headings"],
122
  properties={
123
+ "title": genai.types.Schema(type=genai.types.Type.STRING),
124
+ "headings": genai.types.Schema(type=genai.types.Type.ARRAY, items=genai.types.Schema(type=genai.types.Type.STRING)),
 
 
 
 
 
 
 
 
125
  },
126
  )
127
 
128
+ self.report_fillin = genai.types.Schema(
129
+ type=genai.types.Type.OBJECT,
130
+ required=["content"],
131
+ properties={"content": genai.types.Schema(type=genai.types.Type.STRING)},
132
+ )
133
+
134
 
135
  class ResearchProgress:
136
  def __init__(self, callback):
 
138
  self.callback = callback
139
 
140
  async def update(self, progress: int, message: str):
141
+ self.progress = min(100, self.progress + progress) # max 100
142
+ await self.callback({"progress": self.progress, "message": message})
143
+
144
+ async def setter(self, progress: int, message: str):
145
+ self.progress = min(100, progress) # max 100
146
+ await self.callback({"progress": self.progress, "message": message})
147
 
148
 
149
  class KNet:
150
+ def __init__(self, scraper_instance: CrawlForAIScraper, max_depth: int = 1, num_sites_per_query: int = 5):
151
  self.api_key = os.getenv("GOOGLE_API_KEY")
152
  assert self.api_key, "Google API key is required"
153
  self.scraper = scraper_instance
154
  self.logger = logging.getLogger(__name__)
155
  self.prompt = Prompt()
156
  self.schema = Schema()
157
+ self.progress = None
158
 
159
  # Init Google GenAI client
160
  self.genai_client = genai.Client(api_key=self.api_key)
161
 
162
  # Parameters
163
  self.max_depth = max_depth
 
164
  self.num_sites_per_query = num_sites_per_query
165
 
166
  # Global State
167
+ self.research_plan: list[str] = []
168
+ self.idx_research_plan: int = 0
169
  self.ctx_researcher: list[str] = []
170
  self.ctx_manager: list[str] = []
171
  self.token_count: int = 0
172
 
173
+ async def conduct_research(self, topic: str, progress_callback, max_depth: int, num_sites_per_query: int) -> dict:
174
  # Local Runtime State
175
+ self.progress = ResearchProgress(progress_callback)
176
  self.max_depth = max_depth
 
177
  self.num_sites_per_query = num_sites_per_query
178
 
179
  # Reset global state
180
+ self.research_plan = []
181
+ self.idx_research_plan = 0
182
  self.ctx_researcher = []
183
  self.ctx_manager = []
184
  self.token_count = 0
185
 
186
  try:
187
+ # Generate research plan
188
+ await self.progress.update(0, "Generating research plan...")
189
+ self.research_plan = self.generate_content(self.prompt.research_plan.format(topic=topic), schema=self.schema.research_plan)["steps"]
190
+ self.logger.info(f"Research plan:\n{json.dumps(self.research_plan, indent=2)}")
191
+
192
  # Generate initial search query
193
  query = self.generate_content(
194
+ self.prompt.search_query.format(vertical=self.research_plan[self.idx_research_plan]), schema=self.schema.search_query
195
+ )["branches"][0]
196
+
197
+ # Initialize research tree
198
+ root_node = ResearchNode(query)
199
  to_explore = deque([(root_node, 0)]) # (node, depth) pairs
200
  explored_queries = set() # {string, string, ...}
201
 
202
+ await self.progress.update(0, "Starting research...")
203
 
204
+ # Iterate on research plan
205
+ for self.idx_research_plan, _ in enumerate(self.research_plan):
206
  current_node, current_depth = to_explore.popleft()
207
+ await self.progress.update(100 / (len(self.research_plan) + 1), f"{self.research_plan[self.idx_research_plan]}")
208
+
209
+ while to_explore:
210
+ current_node, current_depth = to_explore.popleft()
211
+ if current_depth > self.max_depth:
212
+ continue
213
+
214
+ self.logger.info(f"Exploring: {current_node.query} (depth: {current_depth})")
215
+ await self.progress.update(0, f"s_{current_node.query}")
216
+
217
+ # Search and scrape
218
+ current_node.data = await self.scraper.search_and_scrape(
219
+ current_node.query, self.num_sites_per_query
220
+ ) # node -> data = [{url:...}, {url:...}, ...]
221
+ self.ctx_researcher.append(json.dumps(current_node.data, indent=2))
222
+ explored_queries.add(current_node.query)
223
+
224
+ # Only branch if we have data and haven't reached max depth
225
+ if self._should_continue_branch(current_node, topic):
226
+ if current_node.data and current_depth < self.max_depth:
227
+ new_branches = self._gen_queries(current_node, topic)
228
+ for branch in new_branches:
229
+ to_explore.appendleft((branch, current_depth + 1))
230
 
231
  # Generate final report
232
+ await self.progress.update(100 / (len(self.research_plan) + 1), "Generating final report...")
233
+ final_report = self._generate_final_report(root_node, topic)
234
 
235
  self.logger.info(f"Research completed. Explored {len(explored_queries)} queries across {root_node.max_depth()} levels")
236
+ await self.progress.update(100, "Research complete!")
237
 
238
+ with open("output.log.json", "w", encoding="utf-8") as f:
239
  json.dump(final_report, f, indent=2)
240
  return final_report
241
 
 
243
  self.logger.error("Research failed", exc_info=True)
244
  raise
245
 
246
+ def _generate_final_report(self, root_node: ResearchNode, topic: str, retry_count: int = 1) -> Dict[str, Any]:
247
  try:
248
+ self.progress.setter(0, "Generating report...")
249
+ findings = "\n\n------\n\n".join(self.ctx_manager)
250
+ with open("ctx_manager.log.txt", "w", encoding="utf-8") as f:
251
  f.write(findings)
252
+
253
+ # Generate report outline
254
+ outline = self.generate_content(self.prompt.report_outline.format(topic=topic, ctx_manager=findings), schema=self.schema.report_outline)
255
+ self.logger.info(f"Report outline:\n{json.dumps(outline, indent=2)}")
256
+ report = []
257
+ # Fill in report outline
258
+ for i, heading in enumerate(outline["headings"]):
259
+ self.progress.update(100 / (len(outline["headings"] + 1)), "Generating report...")
260
+ content = self.generate_content(
261
+ self.prompt.report_fillin.format(
262
+ topic=topic,
263
+ ctx_manager=findings,
264
+ report_outline=["[done] " + outline["title"]] + [f"[done] {h}" for _, h in enumerate(outline["headings"]) if i < _],
265
+ slot=heading,
266
+ ),
267
+ schema=self.schema.report_fillin,
268
+ )["content"]
269
+ report.append({"heading": heading, "content": content})
270
 
271
  # Collate multimedia content
272
  media_content = {"images": [], "videos": [], "links": [], "references": []}
 
288
  def build_tree_structure(node: ResearchNode) -> Dict:
289
  if not node:
290
  return {}
 
291
  sources = [d["url"] for d in node.data if d.get("url")]
292
  return {
293
  "query": node.query,
 
299
  return {
300
  "topic": root_node.query,
301
  "timestamp": datetime.now().isoformat(),
302
+ "content": report,
303
  "media": media_content,
304
  "research_tree": build_tree_structure(root_node),
305
  "metadata": {
 
313
  except Exception as e:
314
  if e in ["GEMINI_RECITATION", "NO_RESPONSE"] and retry_count < 3:
315
  self.logger.error(f"Retrying final report:C:{retry_count / 3}", exc_info=True)
316
+ return self._generate_final_report(root_node, retry_count + 1)
317
  self.logger.error("Error generating final report", exc_info=True)
318
  raise
319
 
 
323
  return []
324
 
325
  prompt = self.prompt.search_query.format(
326
+ vertical=self.research_plan[self.idx_research_plan],
327
+ research_plan="\n".join([f"[done] {step}" for i, step in enumerate(self.research_plan) if i < self.idx_research_plan]),
328
+ past_queries="\n".join([f"[done] {query}" for query in node.get_path_to_root()[1:]]),
329
+ ctx_manager="\n\n---\n\n".join(self.ctx_manager),
330
+ n=1,
331
  )
332
+ response = self.generate_content(prompt, schema=self.schema.search_query, temp=1.5)
333
  self.logger.info(f"Spawn branches '{node.query}':\n{json.dumps(response['branches'], indent=2)}")
334
 
335
  # Add children to current node
 
338
  # |-> child
339
  new_nodes = []
340
  for branch in response.get("branches", []):
341
+ child_node = node.add_child(branch)
342
  new_nodes.append(child_node)
343
 
344
  self.logger.info(f"Spawned {len(new_nodes)} new branch(es)")
 
347
  except Exception as e:
348
  if e in ["GEMINI_RECITATION", "NO_RESPONSE"] and retry_count < 3:
349
  self.logger.error(f"Retrying _gen_queries | C:{retry_count / 3}", exc_info=True)
350
+ return self._gen_queries(node, topic, retry_count + 1)
351
  self.logger.error("_gen_queries failed", exc_info=True)
352
  raise
353
 
 
359
  # Generate summary of key findings into the manager's context
360
  if node.data:
361
  findings = ("\n" + "-" * 10 + "Next data" + "-" * 10 + "\n").join([json.dumps(d, indent=2) for d in node.data])
362
+ response = self.generate_content(self.prompt.site_summary.format(query=node.query, findings=findings), temp=0.2)
363
  self.ctx_manager.append(response)
364
 
365
  # Research manager takes decision to proceed or not
366
  prompt = self.prompt.continue_branch.format(
367
+ research_plan="\n".join([f"[done] {step}" for i, step in enumerate(self.research_plan) if i < self.idx_research_plan]),
368
  query=node.query,
369
+ past_queries="\n".join([f"[done] {query}" for query in node.get_path_to_root()[1:]]),
370
+ ctx_manager="\n\n---\n\n".join(self.ctx_manager),
 
371
  )
372
  response = self.generate_content(prompt, schema=self.schema.continue_branch)
373
  self.logger.info(f"Branch decision '{node.query}': {response['decision']}")
 
377
  except Exception as e:
378
  if e in ["GEMINI_RECITATION", "NO_RESPONSE"] and retry_count < 3:
379
  self.logger.error(f"Retrying branch decision:C:{retry_count / 3}", exc_info=True)
380
+ return self._should_continue_branch(node, topic, retry_count + 1)
381
  self.logger.error("Branch decision failed:", exc_info=True)
382
  raise
383
 
384
+ def generate_content(self, prompt: str, schema: Dict[str, Any] = {}, temp: float = 0.9) -> Dict[str, Any] | str:
385
  safe = [
386
  types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=types.HarmBlockThreshold.BLOCK_NONE),
387
  types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=types.HarmBlockThreshold.BLOCK_NONE),
 
391
  ]
392
  if schema:
393
  generate_content_config = types.GenerateContentConfig(
394
+ temperature=temp, response_mime_type="application/json", safety_settings=safe, response_schema=schema
395
  )
396
  else:
397
+ generate_content_config = types.GenerateContentConfig(temperature=temp, response_mime_type="text/plain", safety_settings=safe)
398
 
399
  try:
400
  response = self.genai_client.models.generate_content(model="gemini-2.0-flash", contents=prompt, config=generate_content_config)
backend/research_node.py CHANGED
@@ -3,9 +3,7 @@ from typing import Any, Dict, List, Optional
3
 
4
 
5
  class ResearchNode:
6
- def __init__(
7
- self, query: str, parent: Optional["ResearchNode"] = None, depth: int = 0
8
- ):
9
  self.query = query
10
  self.parent = parent
11
  self.depth = depth
@@ -18,6 +16,10 @@ class ResearchNode:
18
  return child
19
 
20
  def get_path_to_root(self) -> List[str]:
 
 
 
 
21
  path = [self.query]
22
  current = self
23
  while current.parent:
@@ -33,9 +35,7 @@ class ResearchNode:
33
  def total_children(self) -> int:
34
  if not self.children:
35
  return 0
36
- return len(self.children) + sum(
37
- [child.total_children() for child in self.children]
38
- )
39
 
40
  def get_all_data(self) -> List[Dict[str, Any]]:
41
  data = copy.deepcopy(self.data)
 
3
 
4
 
5
  class ResearchNode:
6
+ def __init__(self, query: str, parent: Optional["ResearchNode"] = None, depth: int = 0):
 
 
7
  self.query = query
8
  self.parent = parent
9
  self.depth = depth
 
16
  return child
17
 
18
  def get_path_to_root(self) -> List[str]:
19
+ """
20
+ Returns the path from this node to the root node.
21
+ List[str]: [root.query, ..., self.query]
22
+ """
23
  path = [self.query]
24
  current = self
25
  while current.parent:
 
35
  def total_children(self) -> int:
36
  if not self.children:
37
  return 0
38
+ return len(self.children) + sum([child.total_children() for child in self.children])
 
 
39
 
40
  def get_all_data(self) -> List[Dict[str, Any]]:
41
  data = copy.deepcopy(self.data)