Soham Waghmare committed on
Commit
70f0982
·
1 Parent(s): fced70d

feat: working langgraph DeepResearch

Browse files
langgraph_backend/app.py CHANGED
@@ -18,12 +18,20 @@ from sse_starlette.sse import EventSourceResponse
18
 
19
  from prompts import (
20
  CONTINUE_BRANCH_PROMPT,
 
 
21
  RESEARCH_PLAN_PROMPT,
22
  SEARCH_QUERY_PROMPT,
23
  SITE_SUMMARY_PROMPT,
24
  )
25
  from research_node import ResearchNode
26
- from schema import ContinueBranch, ResearchPlan, SearchQuery
 
 
 
 
 
 
27
  from scraper import CrawlForAIScraper
28
 
29
  load_dotenv()
@@ -58,17 +66,20 @@ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=os.getenv(
58
 
59
 
60
  class ResearchProgress:
61
- def __init__(self, master_node: ResearchNode):
62
  self.progress = 0
63
- self.master_node = master_node
64
 
65
- def send(self, writer: StreamWriter, progress: int, message: dict, ptype: str):
66
  if ptype == "update":
67
- self.progress = int(min(100, progress)) # max 100
68
- writer({"event": "progress", "data": {"progress": self.progress, **message, "research_tree": self.master_node.build_tree_structure()}})
69
- elif ptype == "setter":
70
  self.progress = int(min(100, self.progress + progress)) # max 100
71
- writer({"event": "progress", "data": {"progress": self.progress, **message, "research_tree": self.master_node.build_tree_structure()}})
 
 
 
 
 
 
 
72
  elif ptype == "result":
73
  self.progress = 100
74
  writer({"event": "result", "data": message})
@@ -95,74 +106,51 @@ class ResearchState(TypedDict, total=False):
95
  token_count: int
96
 
97
 
98
- async def research_plan_node(state: ResearchState) -> ResearchPlan:
99
  writer = get_stream_writer()
100
 
101
- if state["idx_research_plan"] == 0:
102
  topic = state["topic"]
103
  plan = llm.with_structured_output(ResearchPlan).invoke(RESEARCH_PLAN_PROMPT.format(topic=topic), config={"temperature": 1.5})
104
  if "steps" in plan:
105
  steps = plan["steps"]
106
 
107
  logger.info(f"Research plan:\n{json.dumps(steps, indent=2)}")
108
- state["progress"].send(writer, 0, {"message": "Starting research..."}, ptype="setter")
109
 
110
  return {"research_plan": steps}
111
- else:
112
- # TODO: Update the plan based on current information
113
- return dict()
114
 
115
 
116
  async def scrape_node(state: ResearchState) -> ResearchState:
117
- writer = get_stream_writer()
118
-
119
- # Generate initial search query if first step
120
- # TODO: Add a condition based on 1st iter or successive iters
121
- # TODO: Wrap inference in backend.knet.generate_content
122
- if state["idx_research_plan"] == 0:
123
- query = (
124
- llm.with_structured_output(SearchQuery)
125
- .invoke(
126
- SEARCH_QUERY_PROMPT.format(
127
- vertical=state["research_plan"][state["idx_research_plan"]],
128
- topic=state["topic"],
129
- research_plan="None",
130
- past_queries="None",
131
- ctx_manager="None",
132
- n=1,
133
- ),
134
- config={"temperature": 1.5},
135
- )
136
- .get("branches", [""])[0]
137
  )
138
- new_master = copy.deepcopy(state["master_node"])
139
- curr_node = ResearchNode(query)
 
 
 
 
 
140
  new_master.add_child(curr_node.query, node=curr_node)
 
141
  else:
142
- # TODO: Manage the Research Tree like above
143
- query = (
144
- llm.with_structured_output(SearchQuery)
145
- .invoke(
146
- SEARCH_QUERY_PROMPT.format(
147
- vertical=state["research_plan"][state["idx_research_plan"]],
148
- topic=state["topic"],
149
- research_plan="\n".join([f"[done] {step}" for i, step in enumerate(state["research_plan"]) if i < state["idx_research_plan"]]),
150
- past_queries="\n".join([f"[done] {query}" for query in state["current_node"].get_path_to_root()[1:]]),
151
- ctx_manager="\n\n---\n\n".join(state["ctx_manager"]),
152
- n=1,
153
- ),
154
- config={"temperature": 1.5},
155
- )
156
- .get("branches", [""])[0]
157
- )
158
 
159
- # Update progress
160
- state["progress"].send(
161
- writer, 100 / (len(state["research_plan"]) + 1), {"message": f"{state['research_plan'][state['idx_research_plan']]}"}, ptype="update"
162
- )
163
-
164
- # Search and scrape
165
- data = await state["scraper"].search_and_scrape(query, state["num_sites_per_query"]) # node -> data = [{url:...}, {url:...}, ...]
166
  # Add data to context
167
  # src [1] : https://...
168
  # content...
@@ -175,23 +163,43 @@ async def summarize_node(state: ResearchState) -> ResearchState:
175
  upd_ctx_manager = state["ctx_manager"]
176
  if state["current_node"].data:
177
  for idx in range(0, len(state["current_node"].data), 3):
178
- data = state["current_node"].data[idx : idx + 3]
179
- findings = ("\n" + "-" * 10 + "Next data" + "-" * 10 + "\n").join([json.dumps(d, indent=2) for d in data])
180
- summary = llm.invoke(SITE_SUMMARY_PROMPT.format(query=state["current_node"].query, findings=findings), config={"temperature": 0.2})
181
- upd_ctx_manager.append(summary) if isinstance(summary, str) else None
182
  return {"ctx_manager": upd_ctx_manager}
183
 
184
 
185
- async def should_continue(state: ResearchState) -> Command[Literal["plan", "scrape", "gen_report"]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  # If max depth is reached and we are at the last step of the research plan, generate report
187
- if state["current_node"].depth > state["max_depth"] and state["idx_research_plan"] >= len(state["research_plan"]) - 1:
188
  logger.info(f"Branch decision '{state['current_node'].query}': False")
189
  return Command(goto="gen_report")
190
 
191
  # If max depth is reached and we are not at the last step of the research plan, continue with the next step
192
- elif state["current_node"].depth > state["max_depth"] and state["idx_research_plan"] < len(state["research_plan"]) - 1:
193
  logger.info(f"Branch decision '{state['current_node'].query}': False")
194
- return Command(goto="plan", update={"idx_research_plan": state["idx_research_plan"] + 1})
195
 
196
  # If we have not reached max depth and not on last step of the research plan, continue with the next step
197
  decision = llm.with_structured_output(ContinueBranch).invoke(
@@ -203,11 +211,84 @@ async def should_continue(state: ResearchState) -> Command[Literal["plan", "scra
203
  )
204
  )
205
  logger.info(f"Branch decision '{state['current_node'].query}': {decision['decision']}")
206
- return Command(goto="scrape") if decision["decision"] else Command(goto="plan", update={"idx_research_plan": state["idx_research_plan"] + 1})
207
 
208
 
209
  async def gen_report_node(state: ResearchState) -> ResearchState:
210
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
 
213
  # --- Main research logic using LangGraph ---
@@ -216,26 +297,29 @@ async def start_research_workflow(topic: str, scraper: CrawlForAIScraper, max_de
216
  graph = StateGraph(state_schema=ResearchState)
217
  graph.add_node("plan", research_plan_node)
218
  graph.add_node("scrape", scrape_node)
219
- graph.add_node("summarize_node", summarize_node)
220
- graph.add_node("should_continue", should_continue)
221
  graph.add_node("gen_report", gen_report_node)
222
 
223
  graph.add_edge("plan", "scrape")
224
- graph.add_edge("scrape", "summarize_node")
225
- graph.add_edge("summarize_node", "should_continue")
226
  graph.add_edge("gen_report", END)
227
  graph.set_entry_point("plan")
228
  graph = graph.compile()
229
  print(graph.get_graph().draw_mermaid())
230
 
231
  master_node = ResearchNode()
 
 
232
  state: ResearchState = {
233
  "scraper": scraper,
234
- "progress": ResearchProgress(master_node),
235
  "topic": topic,
236
  "max_depth": max_depth,
237
  "num_sites_per_query": num_sites_per_query,
238
  "master_node": master_node,
 
239
  "research_plan": [],
240
  "idx_research_plan": 0,
241
  "ctx_researcher": [],
@@ -243,7 +327,7 @@ async def start_research_workflow(topic: str, scraper: CrawlForAIScraper, max_de
243
  "raster_report": "",
244
  "token_count": 0,
245
  }
246
- async for update in graph.astream(state, stream_mode="custom"):
247
  yield update
248
 
249
 
@@ -280,8 +364,6 @@ async def abort_research(request: Request):
280
  return {"status": "aborted"}
281
 
282
 
283
- # Add more endpoints as needed for test, etc.
284
-
285
  if __name__ == "__main__":
286
  logger.info("Starting KnowledgeNet server...")
287
  import uvicorn
 
18
 
19
  from prompts import (
20
  CONTINUE_BRANCH_PROMPT,
21
+ REPORT_FILLIN_PROMPT,
22
+ REPORT_OUTLINE_PROMPT,
23
  RESEARCH_PLAN_PROMPT,
24
  SEARCH_QUERY_PROMPT,
25
  SITE_SUMMARY_PROMPT,
26
  )
27
  from research_node import ResearchNode
28
+ from schema import (
29
+ ContinueBranch,
30
+ ReportFillin,
31
+ ReportOutline,
32
+ ResearchPlan,
33
+ SearchQuery,
34
+ )
35
  from scraper import CrawlForAIScraper
36
 
37
  load_dotenv()
 
66
 
67
 
68
class ResearchProgress:
    """Tracks overall research progress (0-100) and emits SSE-style events
    through a LangGraph stream writer callable."""

    def __init__(self):
        # The master node is passed per send() call rather than stored here,
        # so this object carries no reference to the research tree.
        self.progress = 0

    def send(self, writer: "StreamWriter", progress: int, message: dict, ptype: str, master_node_for_send: "ResearchNode | None" = None):
        """Emit one progress or result event via *writer*.

        Args:
            writer: callable receiving a single event dict.
            progress: increment when ptype=="update", absolute value when
                ptype=="setter"; ignored for "result".
            message: extra payload merged into the event data.
            ptype: one of "update", "setter", "result"; anything else is a no-op.
            master_node_for_send: research tree root whose structure is
                serialized into progress events (bug fix: previously a missing
                node raised AttributeError; now the tree is sent as None).

        NOTE(review): "update" accumulates while "setter" assigns — the names
        read swapped relative to their behavior; confirm callers expect this.
        """
        if ptype == "update":
            self.progress = int(min(100, self.progress + progress))  # cap at 100
            tree = master_node_for_send.build_tree_structure() if master_node_for_send is not None else None
            writer({"event": "progress", "data": {"progress": self.progress, **message, "research_tree": tree}})
        elif ptype == "setter":
            self.progress = int(min(100, progress))  # cap at 100
            tree = master_node_for_send.build_tree_structure() if master_node_for_send is not None else None
            writer({"event": "progress", "data": {"progress": self.progress, **message, "research_tree": tree}})
        elif ptype == "result":
            self.progress = 100
            writer({"event": "result", "data": message})
 
106
  token_count: int
107
 
108
 
109
async def research_plan_node(state: ResearchState) -> ResearchState:
    """Entry node of the graph: generate the research plan for the topic.

    The LLM is only consulted on the first visit (empty plan). On revisits
    from `should_continue` the existing plan is kept unchanged.
    """
    writer = get_stream_writer()

    if len(state["research_plan"]) == 0:
        topic = state["topic"]
        plan = llm.with_structured_output(ResearchPlan).invoke(RESEARCH_PLAN_PROMPT.format(topic=topic), config={"temperature": 1.5})
        # Bug fix: `steps` was only bound when the LLM output contained a
        # "steps" key, raising NameError below otherwise. Default to [].
        steps = plan.get("steps", [])

        logger.info(f"Research plan:\n{json.dumps(steps, indent=2)}")
        state["progress"].send(writer, 0, {"message": "Starting research..."}, ptype="setter", master_node_for_send=state["master_node"])

        return {"research_plan": steps}
    # Plan already exists: explicit empty update instead of implicit None.
    return {}
 
 
 
122
 
123
 
124
  async def scrape_node(state: ResearchState) -> ResearchState:
125
+ # TODO: idx_research_plan index error here
126
+ query = (
127
+ llm.with_structured_output(SearchQuery)
128
+ .invoke(
129
+ SEARCH_QUERY_PROMPT.format(
130
+ vertical=state["research_plan"][state["idx_research_plan"]],
131
+ topic=state["topic"],
132
+ research_plan="\n".join([f"[done] {step}" for i, step in enumerate(state["research_plan"]) if i < state["idx_research_plan"]]),
133
+ past_queries="\n".join([f"[done] {query}" for query in state["current_node"].get_path_to_root()[1:]]),
134
+ ctx_manager="\n\n---\n\n".join(state["ctx_manager"]),
135
+ n=1,
136
+ ),
137
+ config={"temperature": 1.5},
 
 
 
 
 
 
 
138
  )
139
+ .get("branches", [""])[0]
140
+ )
141
+
142
+ new_master = ResearchNode.deep_copy_tree(state["master_node"])
143
+ curr_node = ResearchNode(query)
144
+ # Add a new vertical node
145
+ if state["current_node"].depth >= state["max_depth"]:
146
  new_master.add_child(curr_node.query, node=curr_node)
147
+ # Add a branch to the current node
148
  else:
149
+ old_curr_node = new_master.find_node(state["current_node"].id)
150
+ old_curr_node.add_child(curr_node.query, node=curr_node)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
+ data = await state["scraper"].search_and_scrape(query, state["num_sites_per_query"])
153
+ curr_node.data = data
 
 
 
 
 
154
  # Add data to context
155
  # src [1] : https://...
156
  # content...
 
163
  upd_ctx_manager = state["ctx_manager"]
164
  if state["current_node"].data:
165
  for idx in range(0, len(state["current_node"].data), 3):
166
+ summary = llm.invoke(
167
+ SITE_SUMMARY_PROMPT.format(query=state["current_node"].query, findings=state["ctx_researcher"][-1]), config={"temperature": 0.2}
168
+ ).text()
169
+ upd_ctx_manager.append(summary)
170
  return {"ctx_manager": upd_ctx_manager}
171
 
172
 
173
+ async def should_continue_node(state: ResearchState) -> Command[Literal["plan", "scrape", "gen_report"]]:
174
+ print( # TODO: Remove this print statement
175
+ json.dumps(
176
+ {
177
+ "current_node": {"query": state["current_node"].query, "depth": state["current_node"].depth},
178
+ "max_depth": state["max_depth"],
179
+ "idx_research_plan": state["idx_research_plan"],
180
+ },
181
+ indent=2,
182
+ )
183
+ )
184
+ writer = get_stream_writer()
185
+ target_progress_for_step = (state["idx_research_plan"] + 1) * (100.0 / (len(state["research_plan"]) if state["research_plan"] else 1))
186
+ state["progress"].send(
187
+ writer,
188
+ target_progress_for_step,
189
+ {"message": f"{state['research_plan'][state['idx_research_plan']]}"},
190
+ ptype="update",
191
+ master_node_for_send=state["master_node"],
192
+ )
193
+
194
  # If max depth is reached and we are at the last step of the research plan, generate report
195
+ if state["current_node"].depth >= state["max_depth"] and state["idx_research_plan"] >= len(state["research_plan"]) - 1:
196
  logger.info(f"Branch decision '{state['current_node'].query}': False")
197
  return Command(goto="gen_report")
198
 
199
  # If max depth is reached and we are not at the last step of the research plan, continue with the next step
200
+ if state["current_node"].depth >= state["max_depth"] and state["idx_research_plan"] < len(state["research_plan"]) - 1:
201
  logger.info(f"Branch decision '{state['current_node'].query}': False")
202
+ return Command(goto="plan", update={"idx_research_plan": state["idx_research_plan"] + 1, "current_node": state["master_node"]})
203
 
204
  # If we have not reached max depth and not on last step of the research plan, continue with the next step
205
  decision = llm.with_structured_output(ContinueBranch).invoke(
 
211
  )
212
  )
213
  logger.info(f"Branch decision '{state['current_node'].query}': {decision['decision']}")
214
+ return Command(goto="scrape", update={"idx_research_plan": state["idx_research_plan"] + 0 if decision["decision"] else 1})
215
 
216
 
217
async def gen_report_node(state: ResearchState) -> ResearchState:
    """Final node: write the markdown report, collate media from all scraped
    sources, and emit the terminal "result" event through the stream writer."""
    writer = get_stream_writer()
    state["progress"].send(writer, 0, {"message": "Generating report..."}, ptype="setter", master_node_for_send=state["master_node"])
    findings = "\n\n------\n\n".join(state["ctx_manager"])
    # Debug dump of the accumulated research context.
    with open("ctx_manager.log.txt", "w", encoding="utf-8") as f:
        f.write(findings)

    # Generate report outline
    outline = llm.with_structured_output(ReportOutline).invoke(REPORT_OUTLINE_PROMPT.format(topic=state["topic"], ctx_manager=findings))
    logger.info(f"Report outline:\n{json.dumps(outline, indent=2)}")
    report = []
    raster_report = f"# {outline['title']}\n\n"

    # Fill in the outline one heading at a time
    for i, heading in enumerate(outline["headings"]):
        state["progress"].send(
            writer,
            100 / (len(outline["headings"]) + 1),
            {"message": "Generating report..."},
            ptype="update",
            master_node_for_send=state["master_node"],
        )
        # NOTE(review): this marks headings *after* the current one as [done];
        # confirm the prompt actually expects remaining (not completed) sections.
        # (Was `for _, h in enumerate(...) if i < _` — `_` misused as an index.)
        outline_progress = ["[done] " + outline["title"]] + [f"[done] {h}" for j, h in enumerate(outline["headings"]) if i < j]
        content = llm.with_structured_output(ReportFillin).invoke(
            REPORT_FILLIN_PROMPT.format(
                topic=state["topic"],
                ctx_manager=findings,
                report_progress=raster_report,
                report_outline=outline_progress,
                slot=heading,
            ),
        )["content"]
        # Remove the heading if the LLM echoed it despite instructions
        idx_heading = content.find(heading)
        if idx_heading != -1:
            content = content[idx_heading + len(heading) :].strip()
        report.append({"heading": heading, "content": content})
        raster_report += f"\n\n## {heading}\n\n{content}"

    # Collate multimedia content from every scraped source in the tree
    media_content = {"images": [], "videos": [], "links": []}
    all_sources_data = state["master_node"].get_all_data()
    for data in all_sources_data:
        if data.get("images"):
            media_content["images"].extend(data["images"])
        if data.get("videos"):
            media_content["videos"].extend(data["videos"])
        if data.get("links"):
            media_content["links"].extend([{"url": link["href"], "text": link["text"]} for link in data["links"]])
    # Dedupe deterministically: dict.fromkeys keeps first-seen order, unlike
    # list(set(...)) which reordered the media arbitrarily on every run.
    media_content["images"] = list(dict.fromkeys(media_content["images"]))
    media_content["videos"] = list(dict.fromkeys(media_content["videos"]))
    media_content["links"] = [json.loads(d) for d in dict.fromkeys(json.dumps(d, sort_keys=True) for d in media_content["links"])]

    result = {
        "topic": state["topic"],
        "timestamp": datetime.now().isoformat(),
        "content": raster_report,
        "media": media_content,
        "research_tree": state["master_node"].build_tree_structure(),
        "metadata": {
            "total_queries": state["master_node"].total_children(),
            "total_sources": len(all_sources_data),
            "max_depth_reached": state["master_node"].max_depth(),
            "total_tokens": state["token_count"],
        },
    }
    # Debug dump of the final result payload.
    with open("output.log.json", "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)
    state["progress"].send(
        writer,
        100,
        result,
        ptype="result",
    )
292
 
293
 
294
  # --- Main research logic using LangGraph ---
 
297
  graph = StateGraph(state_schema=ResearchState)
298
  graph.add_node("plan", research_plan_node)
299
  graph.add_node("scrape", scrape_node)
300
+ graph.add_node("summarize", summarize_node)
301
+ graph.add_node("should_continue", should_continue_node)
302
  graph.add_node("gen_report", gen_report_node)
303
 
304
  graph.add_edge("plan", "scrape")
305
+ graph.add_edge("scrape", "summarize")
306
+ graph.add_edge("summarize", "should_continue")
307
  graph.add_edge("gen_report", END)
308
  graph.set_entry_point("plan")
309
  graph = graph.compile()
310
  print(graph.get_graph().draw_mermaid())
311
 
312
  master_node = ResearchNode()
313
+ initial_current_node = master_node
314
+
315
  state: ResearchState = {
316
  "scraper": scraper,
317
+ "progress": ResearchProgress(),
318
  "topic": topic,
319
  "max_depth": max_depth,
320
  "num_sites_per_query": num_sites_per_query,
321
  "master_node": master_node,
322
+ "current_node": initial_current_node,
323
  "research_plan": [],
324
  "idx_research_plan": 0,
325
  "ctx_researcher": [],
 
327
  "raster_report": "",
328
  "token_count": 0,
329
  }
330
+ async for update in graph.astream(state, {"recursion_limit": 1000}, stream_mode="custom"):
331
  yield update
332
 
333
 
 
364
  return {"status": "aborted"}
365
 
366
 
 
 
367
  if __name__ == "__main__":
368
  logger.info("Starting KnowledgeNet server...")
369
  import uvicorn
langgraph_backend/research_node.py CHANGED
@@ -75,3 +75,19 @@ class ResearchNode:
75
  "children": [child.build_tree_structure() for child in self.children],
76
  }
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  "children": [child.build_tree_structure() for child in self.children],
76
  }
77
 
78
+ # Return deep copy with node pointers | Isolated function
79
+ def deep_copy_tree(root: Optional[Self] = None) -> Self:
80
+ """
81
+ Returns a deep copy of the tree starting from this node.
82
+ """
83
+ if root is None:
84
+ return None
85
+ new_node = ResearchNode(root.query, depth=root.depth)
86
+ new_node.id = root.id
87
+ new_node.data = copy.deepcopy(root.data)
88
+ for child in root.children:
89
+ new_child = ResearchNode.deep_copy_tree(child)
90
+ new_child.parent = new_node
91
+ new_node.children.append(new_child)
92
+ return new_node
93
+
langgraph_backend/scraper.py CHANGED
@@ -188,8 +188,8 @@ class CrawlForAIScraper:
188
  all_videos = list(set(all_videos + media_videos))
189
 
190
  data = {
191
- "url": result.url,
192
- "text": result.markdown,
193
  "images": all_images,
194
  "videos": all_videos,
195
  "links": self._extract_links(result.links["external"]),
 
188
  all_videos = list(set(all_videos + media_videos))
189
 
190
  data = {
191
+ "url": str(result.url),
192
+ "text": str(result.markdown),
193
  "images": all_images,
194
  "videos": all_videos,
195
  "links": self._extract_links(result.links["external"]),