Soham Waghmare committed
Commit fced70d · 1 Parent(s): fbfef4e

feat: enhance research workflow with progress tracking and node management

langgraph_backend/app.py CHANGED
@@ -1,24 +1,36 @@
  import asyncio
+ import copy
  import json
  import logging
  import os
  from datetime import datetime
- from typing import Any, Dict, List, Optional, TypedDict
+ from typing import Annotated, Any, Dict, List, Literal, Optional, TypedDict

  from dotenv import load_dotenv
  from fastapi import FastAPI, Request
  from fastapi.middleware.cors import CORSMiddleware
  from fastapi.responses import StreamingResponse
  from langchain_google_genai import ChatGoogleGenerativeAI
+ from langgraph.config import get_stream_writer
  from langgraph.graph import END, StateGraph
+ from langgraph.types import Command, StreamWriter
  from sse_starlette.sse import EventSourceResponse

- from prompts import RESEARCH_PLAN_PROMPT, SEARCH_QUERY_PROMPT
- from schema import ResearchPlan, SearchQuery
+ from prompts import (
+     CONTINUE_BRANCH_PROMPT,
+     RESEARCH_PLAN_PROMPT,
+     SEARCH_QUERY_PROMPT,
+     SITE_SUMMARY_PROMPT,
+ )
+ from research_node import ResearchNode
+ from schema import ContinueBranch, ResearchPlan, SearchQuery
  from scraper import CrawlForAIScraper

  load_dotenv()

+ # Today's Date
+ DATE = datetime.now().strftime("%d %b, %Y")
+
  logger = logging.getLogger(__name__)
  logging.basicConfig(level=logging.INFO)
 
@@ -45,94 +57,194 @@ async def health_check():
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=os.getenv("GOOGLE_API_KEY"))


+ class ResearchProgress:
+     def __init__(self, master_node: ResearchNode):
+         self.progress = 0
+         self.master_node = master_node
+
+     def send(self, writer: StreamWriter, progress: int, message: dict, ptype: str):
+         if ptype == "update":
+             self.progress = int(min(100, progress))  # max 100
+             writer({"event": "progress", "data": {"progress": self.progress, **message, "research_tree": self.master_node.build_tree_structure()}})
+         elif ptype == "setter":
+             self.progress = int(min(100, self.progress + progress))  # max 100
+             writer({"event": "progress", "data": {"progress": self.progress, **message, "research_tree": self.master_node.build_tree_structure()}})
+         elif ptype == "result":
+             self.progress = 100
+             writer({"event": "result", "data": message})
+
+
  # --- State schema for LangGraph ---
  class ResearchState(TypedDict, total=False):
+     scraper: CrawlForAIScraper
+     progress: ResearchProgress
+
+     # Parameters
      topic: str
+     max_depth: int
+     num_sites_per_query: int

+     # Global State
+     master_node: ResearchNode
+     current_node: ResearchNode
      research_plan: list[str]
      idx_research_plan: int
      ctx_researcher: list[str]
      ctx_manager: list[str]
+     raster_report: str
      token_count: int

-     scraper: CrawlForAIScraper
-     max_depth: int
-     num_sites_per_query: int
-

  async def research_plan_node(state: ResearchState) -> ResearchPlan:
-     topic = state["topic"]
-     plan = await llm.with_structured_output(ResearchPlan).ainvoke(RESEARCH_PLAN_PROMPT.format(topic=topic), temperature=1.5)
-     if hasattr(plan, "steps"):
-         steps = plan["steps"]
-     logger.info(f"Research plan:\n{json.dumps(steps, indent=2)}")
-     return steps
+     writer = get_stream_writer()
+
+     if state["idx_research_plan"] == 0:
+         topic = state["topic"]
+         plan = llm.with_structured_output(ResearchPlan).invoke(RESEARCH_PLAN_PROMPT.format(topic=topic), config={"temperature": 1.5})
+         if "steps" in plan:
+             steps = plan["steps"]
+
+         logger.info(f"Research plan:\n{json.dumps(steps, indent=2)}")
+         state["progress"].send(writer, 0, {"message": "Starting research..."}, ptype="setter")
+
+         return {"research_plan": steps}
+     else:
+         # TODO: Update the plan based on current information
+         return dict()


  async def scrape_node(state: ResearchState) -> ResearchState:
-     topic = state["topic"]
-     scraper = state["scraper"]
-     max_depth = state["max_depth"]
-     num_sites_per_query = state["num_sites_per_query"]
-
-     # Generate initial search query
-     query = llm.with_structured_output(SearchQuery).invoke(
-         SEARCH_QUERY_PROMPT.format(
-             vertical=state["research_plan"][state["idx_research_plan"]], topic=topic, research_plan="None", past_queries="None", ctx_manager="None", n=1
-         ),
-         temperature=1.5,
+     writer = get_stream_writer()
+
+     # Generate initial search query if first step
+     # TODO: Add a condition based on 1st iter or successive iters
+     # TODO: Wrap inference in backend.knet.generate_content
+     if state["idx_research_plan"] == 0:
+         query = (
+             llm.with_structured_output(SearchQuery)
+             .invoke(
+                 SEARCH_QUERY_PROMPT.format(
+                     vertical=state["research_plan"][state["idx_research_plan"]],
+                     topic=state["topic"],
+                     research_plan="None",
+                     past_queries="None",
+                     ctx_manager="None",
+                     n=1,
+                 ),
+                 config={"temperature": 1.5},
+             )
+             .get("branches", [""])[0]
+         )
+         new_master = copy.deepcopy(state["master_node"])
+         curr_node = ResearchNode(query)
+         new_master.add_child(curr_node.query, node=curr_node)
+     else:
+         # TODO: Manage the Research Tree like above
+         query = (
+             llm.with_structured_output(SearchQuery)
+             .invoke(
+                 SEARCH_QUERY_PROMPT.format(
+                     vertical=state["research_plan"][state["idx_research_plan"]],
+                     topic=state["topic"],
+                     research_plan="\n".join([f"[done] {step}" for i, step in enumerate(state["research_plan"]) if i < state["idx_research_plan"]]),
+                     past_queries="\n".join([f"[done] {query}" for query in state["current_node"].get_path_to_root()[1:]]),
+                     ctx_manager="\n\n---\n\n".join(state["ctx_manager"]),
+                     n=1,
+                 ),
+                 config={"temperature": 1.5},
+             )
+             .get("branches", [""])[0]
+         )
+
+     # Update progress
+     state["progress"].send(
+         writer, 100 / (len(state["research_plan"]) + 1), {"message": f"{state['research_plan'][state['idx_research_plan']]}"}, ptype="update"
      )

      # Search and scrape
-     data = await state["scraper"].search_and_scrape(
-         query, num_sites_per_query
-     )  # node -> data = [{url:...}, {url:...}, ...]
-     state["ctx_researcher"].append(json.dumps(data, indent=2))
-     pass
-     # TODO: Implement the scraping logic and update the state with the scraped data
+     data = await state["scraper"].search_and_scrape(query, state["num_sites_per_query"])  # node -> data = [{url:...}, {url:...}, ...]
+     # Add data to context
+     # src [1] : https://...
+     # content...
+     upd_ctx_researcher = state["ctx_researcher"] + ["\n\n---\n\n".join([f"src [{i + 1}] : {d['url']}\n{d['text']}" for i, d in enumerate(data)])]
+     return {"ctx_researcher": upd_ctx_researcher, "master_node": new_master, "current_node": curr_node}
+
+
+ async def summarize_node(state: ResearchState) -> ResearchState:
+     # Generate summary of key findings into the manager's context
+     upd_ctx_manager = state["ctx_manager"]
+     if state["current_node"].data:
+         for idx in range(0, len(state["current_node"].data), 3):
+             data = state["current_node"].data[idx : idx + 3]
+             findings = ("\n" + "-" * 10 + "Next data" + "-" * 10 + "\n").join([json.dumps(d, indent=2) for d in data])
+             summary = llm.invoke(SITE_SUMMARY_PROMPT.format(query=state["current_node"].query, findings=findings), config={"temperature": 0.2})
+             upd_ctx_manager.append(summary) if isinstance(summary, str) else None
+     return {"ctx_manager": upd_ctx_manager}
+
+
+ async def should_continue(state: ResearchState) -> Command[Literal["plan", "scrape", "gen_report"]]:
+     # If max depth is reached and we are at the last step of the research plan, generate report
+     if state["current_node"].depth > state["max_depth"] and state["idx_research_plan"] >= len(state["research_plan"]) - 1:
+         logger.info(f"Branch decision '{state['current_node'].query}': False")
+         return Command(goto="gen_report")
+
+     # If max depth is reached and we are not at the last step of the research plan, continue with the next step
+     elif state["current_node"].depth > state["max_depth"] and state["idx_research_plan"] < len(state["research_plan"]) - 1:
+         logger.info(f"Branch decision '{state['current_node'].query}': False")
+         return Command(goto="plan", update={"idx_research_plan": state["idx_research_plan"] + 1})
+
+     # Otherwise, ask the LLM whether this branch should be explored further
+     decision = llm.with_structured_output(ContinueBranch).invoke(
+         CONTINUE_BRANCH_PROMPT.format(
+             research_plan="\n".join([f"[done] {step}" for i, step in enumerate(state["research_plan"]) if i < state["idx_research_plan"]]),
+             query=state["current_node"].query,
+             past_queries="\n".join([f"[done] {query}" for query in state["current_node"].get_path_to_root()[1:]]),
+             ctx_manager="\n\n---\n\n".join(state["ctx_manager"]),
+         )
+     )
+     logger.info(f"Branch decision '{state['current_node'].query}': {decision['decision']}")
+     return Command(goto="scrape") if decision["decision"] else Command(goto="plan", update={"idx_research_plan": state["idx_research_plan"] + 1})
+
+
+ async def gen_report_node(state: ResearchState) -> ResearchState:
+     return


  # --- Main research logic using LangGraph ---
- async def run_research(topic, scraper, max_depth, num_sites_per_query):
+ async def start_research_workflow(topic: str, scraper: CrawlForAIScraper, max_depth: int, num_sites_per_query: int):
      # Build the research graph
      graph = StateGraph(state_schema=ResearchState)
      graph.add_node("plan", research_plan_node)
      graph.add_node("scrape", scrape_node)
+     graph.add_node("summarize_node", summarize_node)
+     graph.add_node("should_continue", should_continue)
      graph.add_node("gen_report", gen_report_node)

      graph.add_edge("plan", "scrape")
-     graph.add_edge("scrape", "conditional", "plan", "gen_report")
+     graph.add_edge("scrape", "summarize_node")
+     graph.add_edge("summarize_node", "should_continue")
      graph.add_edge("gen_report", END)
      graph.set_entry_point("plan")
      graph = graph.compile()
      print(graph.get_graph().draw_mermaid())

-     state = {
-         "topic": topic,
+     master_node = ResearchNode()
+     state: ResearchState = {
          "scraper": scraper,
+         "progress": ResearchProgress(master_node),
+         "topic": topic,
          "max_depth": max_depth,
          "num_sites_per_query": num_sites_per_query,
+         "master_node": master_node,
+         "research_plan": [],
+         "idx_research_plan": 0,
+         "ctx_researcher": [],
+         "ctx_manager": [],
+         "raster_report": "",
+         "token_count": 0,
      }
-     async for step in graph.astream(state):
-         progress = step.get("progress", 0)
-         message = step.get("message", "")
-         yield {
-             "event": "status",
-             "data": json.dumps({"progress": progress, "message": message}),
-         }
-     yield {
-         "event": "research_complete",
-         "data": json.dumps(
-             {
-                 "topic": step["topic"],
-                 "timestamp": step["timestamp"],
-                 "content": step["content"],
-                 "media": step["media"],
-                 "research_tree": step["research_tree"],
-                 "metadata": step["metadata"],
-             }
-         ),
-     }
+     async for update in graph.astream(state, stream_mode="custom"):
+         yield update


  @app.post("/start_research")
@@ -151,7 +263,7 @@ async def start_research(request: Request):
      scraper = sessions[session_id]["scraper"]

      async def event_generator():
-         async for event in run_research(topic, scraper, max_depth, num_sites_per_query):
+         async for event in start_research_workflow(topic, scraper, max_depth, num_sites_per_query):
              yield event

      return EventSourceResponse(event_generator())
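
Note on the streaming pattern above: each node obtains a writer via get_stream_writer(), and whatever that node writes is yielded by graph.astream(..., stream_mode="custom"), which event_generator() then forwards to EventSourceResponse. A minimal, self-contained sketch of that flow (not part of this commit; DemoState, demo_node, and the payload shape are illustrative, and it assumes a langgraph release that provides get_stream_writer and the "custom" stream mode):

import asyncio
from typing import TypedDict

from langgraph.config import get_stream_writer
from langgraph.graph import END, StateGraph


class DemoState(TypedDict, total=False):
    topic: str


async def demo_node(state: DemoState) -> DemoState:
    # Anything passed to the writer surfaces directly in the "custom" stream.
    writer = get_stream_writer()
    writer({"event": "progress", "data": {"progress": 50, "message": state["topic"]}})
    return {}


async def main():
    graph = StateGraph(state_schema=DemoState)
    graph.add_node("demo", demo_node)
    graph.add_edge("demo", END)
    graph.set_entry_point("demo")
    compiled = graph.compile()
    async for update in compiled.astream({"topic": "hello"}, stream_mode="custom"):
        print(update)  # {'event': 'progress', 'data': {...}}


asyncio.run(main())

Run directly, it prints the single progress payload emitted by demo_node, which is exactly what ResearchProgress.send feeds into the SSE response.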
langgraph_backend/research_node.py ADDED
@@ -0,0 +1,77 @@
+ import copy
+ from typing import Any, Dict, List, Optional, Self
+ import uuid
+
+
+ class ResearchNode:
+     def __init__(self, query: str = "_", parent: Optional[Self] = None, depth: int = 0):
+         self.parent = parent
+         self.id = str(uuid.uuid4())
+         self.query = query
+         self.depth = depth
+         self.children: List[ResearchNode] = []
+         self.data: List[Dict[str, Any]] = []
+
+     def find_node(self, node_id: str) -> Optional[Self]:
+         """
+         Returns the node with the given id.
+         If not found, returns None.
+         """
+         if self.id == node_id:
+             return self
+         for child in self.children:
+             found = child.find_node(node_id)
+             if found:
+                 return found
+         return None
+
+     def add_child(self, query: str, node: Optional[Self] = None) -> Self:
+         if node:
+             child = node
+             child.parent = self
+             child.depth = self.depth + 1
+         else:
+             child = ResearchNode(query, parent=self, depth=self.depth + 1)
+         self.children.append(child)
+         return child
+
+     def get_path_to_root(self) -> List[str]:
+         """
+         Returns the path from this node to the root node.
+         List[str]: [root.query, ..., self.query]
+         """
+         path = [self.query]
+         current = self
+         while current.parent:
+             current = current.parent
+             path.append(current.query)
+         return list(reversed(path))
+
+     def max_depth(self) -> int:
+         if not self.children:
+             return self.depth
+         return max([child.max_depth() for child in self.children])
+
+     def total_children(self) -> int:
+         if not self.children:
+             return 0
+         return len(self.children) + sum([child.total_children() for child in self.children])
+
+     def get_all_data(self) -> List[Dict[str, Any]]:
+         data = copy.deepcopy(self.data)
+         for child in self.children:
+             data.extend(child.get_all_data())
+         return data
+
+     # Build research tree structure
+     def build_tree_structure(self) -> Dict:
+         if not self:
+             return {}
+         sources = {d["url"]: d["text"] for d in self.data if d.get("url") and d.get("text")}
+         return {
+             "query": self.query,
+             "depth": self.depth,
+             "sources": sources,
+             "children": [child.build_tree_structure() for child in self.children],
+         }
+
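
Usage sketch for the new ResearchNode helpers (not part of the commit; the query strings and URL are made up). The root is a placeholder node with query "_" at depth 0, and add_child wires parent, depth, and the children list:

from research_node import ResearchNode

root = ResearchNode()                            # placeholder root, query "_", depth 0
first = root.add_child("initial search query")   # depth 1
deeper = first.add_child("follow-up query")      # depth 2
deeper.data.append({"url": "https://example.com", "text": "scraped text"})

print(deeper.get_path_to_root())   # ['_', 'initial search query', 'follow-up query']
print(root.max_depth())            # 2
print(root.total_children())       # 2
print(root.build_tree_structure()) # nested dict with query/depth/sources/children

build_tree_structure() is what ResearchProgress.send attaches to every progress event, so the frontend always receives the current shape of the research tree.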
langgraph_backend/scraper.py CHANGED
@@ -7,7 +7,7 @@ from urllib.parse import quote_plus

  import requests
  from bs4 import BeautifulSoup
- from crawl4ai import AsyncWebCrawler, BrowserConfig, CacheMode
+ from crawl4ai import AsyncWebCrawler, BrowserConfig, CacheMode, CrawlerRunConfig


  class CrawlForAIScraper:
@@ -70,6 +70,7 @@ class CrawlForAIScraper:
              cache_mode=CacheMode.BYPASS,
              delay_before_return_html=2,
              scan_full_page=True,
+             config=CrawlerRunConfig(verbose=False),
          )

          soup = BeautifulSoup(result.html, "html.parser")
@@ -119,6 +120,7 @@ class CrawlForAIScraper:
              cache_mode=CacheMode.BYPASS,
              delay_before_return_html=2,
              scan_full_page=True,
+             config=CrawlerRunConfig(verbose=False),
          )

          soup = BeautifulSoup(result.html, "html.parser")
@@ -157,6 +159,7 @@ class CrawlForAIScraper:
              delay_before_return_html=2,
              exclude_external_images=True,
              page_timeout=25000,
+             config=CrawlerRunConfig(verbose=False),
          )
          scraped_sites = []
          for result in results:
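
Side note on the new config kwarg: in recent crawl4ai releases the loose keyword arguments passed alongside it (cache_mode, delay_before_return_html, scan_full_page, page_timeout) are also fields of CrawlerRunConfig, so the run options could be consolidated into a single config object. A hedged sketch, assuming that API:

import asyncio

from crawl4ai import AsyncWebCrawler, CacheMode, CrawlerRunConfig

# One config object instead of repeating the same kwargs at every arun call site.
run_cfg = CrawlerRunConfig(
    cache_mode=CacheMode.BYPASS,
    delay_before_return_html=2,
    scan_full_page=True,
    verbose=False,
)


async def fetch_html(url: str) -> str:
    async with AsyncWebCrawler() as crawler:
        result = await crawler.arun(url=url, config=run_cfg)
        return result.html


asyncio.run(fetch_html("https://example.com"))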