"""KnowledgeNet research server.

FastAPI app that drives an iterative web-research workflow with LangGraph:
plan -> scrape -> summarize -> (continue?) -> report, streaming progress to
the client over SSE. Sessions (one scraper per client) are kept in memory.
"""

import asyncio
import copy
import json
import logging
import os
from datetime import datetime
from typing import Annotated, Any, Dict, List, Literal, Optional, TypedDict

from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.config import get_stream_writer
from langgraph.graph import END, StateGraph
from langgraph.types import Command, StreamWriter
from sse_starlette.sse import EventSourceResponse

from prompts import (
    CONTINUE_BRANCH_PROMPT,
    REPORT_FILLIN_PROMPT,
    REPORT_OUTLINE_PROMPT,
    RESEARCH_PLAN_PROMPT,
    SEARCH_QUERY_PROMPT,
    SITE_SUMMARY_PROMPT,
)
from research_node import ResearchNode
from schema import (
    ContinueBranch,
    ReportFillin,
    ReportOutline,
    ResearchPlan,
    SearchQuery,
)
from scraper import CrawlForAIScraper
from agent_tools import invoke_agent

load_dotenv()

# Today's Date — not referenced in this file; presumably interpolated into
# prompt templates elsewhere. TODO(review): confirm before removing.
DATE = datetime.now().strftime("%d %b, %Y")

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

app = FastAPI()
# BUG FIX: the original default was "," which split into ["", ""] (two empty,
# useless origins). Default to an empty list and drop blank/whitespace entries.
CORS_ALLOWED_ORIGINS = [o.strip() for o in os.getenv("ALLOWED_ORIGINS", "").split(",") if o.strip()]
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Session management (in-memory for now) — maps session_id -> {"scraper": CrawlForAIScraper}
sessions: Dict[str, Dict[str, Any]] = {}


@app.get("/health")
async def health_check():
    """Liveness probe."""
    return {"status": "ok"}


# --- LangChain LLM setup (Gemini, correct usage) ---
llm = ChatGoogleGenerativeAI(model="gemini-flash-latest", google_api_key=os.getenv("GOOGLE_API_KEY"))


class ResearchProgress:
    """Tracks a 0-100 progress value and emits progress/result events
    through a LangGraph stream writer."""

    def __init__(self):
        # Removed master_node from __init__; it is now passed per send() call.
        self.progress = 0

    def send(
        self,
        writer: StreamWriter,
        progress: int,
        message: dict,
        ptype: str,
        master_node_for_send: Optional[ResearchNode] = None,
    ):
        """Emit one event.

        ptype:
          - "update": add `progress` to the running total (capped at 100).
          - "setter": set the running total to `progress` (capped at 100).
          - "result": mark 100% and emit the final payload as-is.

        NOTE(review): "update" and "setter" dereference `master_node_for_send`
        unconditionally — callers must pass it for those ptypes.
        """
        if ptype == "result":
            self.progress = 100
            writer({"event": "result", "data": message})
            return
        if ptype == "update":
            self.progress = int(min(100, self.progress + progress))  # max 100
        elif ptype == "setter":
            self.progress = int(min(100, progress))  # max 100
        else:
            return  # unknown ptype: preserve original behavior (silently ignore)
        writer(
            {
                "event": "progress",
                "data": {
                    "progress": self.progress,
                    **message,
                    "research_tree": master_node_for_send.build_tree_structure(),
                },
            }
        )


# --- State schema for LangGraph ---
class ResearchState(TypedDict, total=False):
    scraper: CrawlForAIScraper
    progress: ResearchProgress
    # Parameters
    topic: str
    max_depth: int
    num_sites_per_query: int
    # Global State
    master_node: ResearchNode
    current_node: ResearchNode
    research_plan: list[str]
    idx_research_plan: int
    ctx_researcher: list[str]
    ctx_manager: list[str]
    raster_report: str
    token_count: int


async def research_plan_node(state: ResearchState) -> ResearchState:
    """Generate the high-level research plan (list of steps) on first entry.

    No-op when a plan already exists (the graph loops back here between
    plan steps).
    """
    writer = get_stream_writer()
    if len(state["research_plan"]) == 0:
        topic = state["topic"]
        plan = llm.with_structured_output(ResearchPlan).invoke(
            RESEARCH_PLAN_PROMPT.format(topic=topic), config={"temperature": 1.5}
        )
        if "steps" in plan:
            steps = plan["steps"]
            logger.info(f"Research plan:\n{json.dumps(steps, indent=2)}")
            state["progress"].send(
                writer, 0, {"message": "Starting research..."}, ptype="setter", master_node_for_send=state["master_node"]
            )
            return {"research_plan": steps}


async def scrape_node(state: ResearchState) -> ResearchState:
    """Generate one search query for the current plan step, scrape results,
    and graft a new node onto a copy of the research tree."""
    # TODO: idx_research_plan index error here
    query = (
        llm.with_structured_output(SearchQuery)
        .invoke(
            SEARCH_QUERY_PROMPT.format(
                vertical=state["research_plan"][state["idx_research_plan"]],
                topic=state["topic"],
                research_plan="\n".join(
                    [f"[done] {step}" for i, step in enumerate(state["research_plan"]) if i < state["idx_research_plan"]]
                ),
                past_queries="\n".join([f"[done] {query}" for query in state["current_node"].get_path_to_root()[1:]]),
                ctx_manager="\n\n---\n\n".join(state["ctx_manager"]),
                n=1,
            ),
            config={"temperature": 1.5},
        )
        .get("branches", [""])[0]
    )
    # Work on a deep copy so LangGraph sees a new object for the state update.
    new_master = ResearchNode.deep_copy_tree(state["master_node"])
    curr_node = ResearchNode(query)
    # Add a new vertical node
    if state["current_node"].depth >= state["max_depth"]:
        new_master.add_child(curr_node.query, node=curr_node)
    # Add a branch to the current node
    else:
        old_curr_node = new_master.find_node(state["current_node"].id)
        old_curr_node.add_child(curr_node.query, node=curr_node)
    data = await state["scraper"].search_and_scrape(query, state["num_sites_per_query"])
    curr_node.data = data
    # Add data to context, formatted as:
    # src [1] : https://...
    # content...
    upd_ctx_researcher = state["ctx_researcher"] + [
        "\n\n---\n\n".join([f"src [{i + 1}] : {d['url']}\n{d['text']}" for i, d in enumerate(data)])
    ]
    return {"ctx_researcher": upd_ctx_researcher, "master_node": new_master, "current_node": curr_node}


async def summarize_node(state: ResearchState) -> ResearchState:
    """Summarize the latest researcher findings into the manager's context."""
    # Copy rather than mutate state["ctx_manager"] in place.
    upd_ctx_manager = list(state["ctx_manager"])
    if state["current_node"].data:
        # BUG FIX: the original looped `for idx in range(0, len(data), 3)` but
        # never used `idx`, appending several *identical* summaries (the prompt
        # always received the full ctx_researcher[-1]). Summarize once.
        summary = llm.invoke(
            SITE_SUMMARY_PROMPT.format(query=state["current_node"].query, findings=state["ctx_researcher"][-1]),
            config={"temperature": 0.2},
        ).text()
        upd_ctx_manager.append(summary)
    return {"ctx_manager": upd_ctx_manager}


async def should_continue_node(state: ResearchState) -> Command[Literal["plan", "scrape", "gen_report"]]:
    """Decide whether to deepen the current branch, move to the next plan
    step, or generate the final report."""
    logger.debug(
        json.dumps(
            {
                "current_node": {"query": state["current_node"].query, "depth": state["current_node"].depth},
                "max_depth": state["max_depth"],
                "idx_research_plan": state["idx_research_plan"],
            },
            indent=2,
        )
    )
    writer = get_stream_writer()
    # NOTE(review): this computes an absolute target but is sent as an
    # *increment* (ptype="update") — looks suspicious, preserved as-is.
    target_progress_for_step = (state["idx_research_plan"] + 1) * (
        100.0 / (len(state["research_plan"]) if state["research_plan"] else 1)
    )
    state["progress"].send(
        writer,
        target_progress_for_step,
        {"message": f"{state['research_plan'][state['idx_research_plan']]}"},
        ptype="update",
        master_node_for_send=state["master_node"],
    )
    at_max_depth = state["current_node"].depth >= state["max_depth"]
    on_last_step = state["idx_research_plan"] >= len(state["research_plan"]) - 1
    # If max depth is reached and we are at the last step of the research plan, generate report
    if at_max_depth and on_last_step:
        logger.info(f"Branch decision '{state['current_node'].query}': False")
        return Command(goto="gen_report")
    # If max depth is reached and we are not at the last step of the research plan, continue with the next step
    if at_max_depth:
        logger.info(f"Branch decision '{state['current_node'].query}': False")
        return Command(
            goto="plan",
            update={"idx_research_plan": state["idx_research_plan"] + 1, "current_node": state["master_node"]},
        )
    # If we have not reached max depth and not on last step of the research plan, continue with the next step
    decision = llm.with_structured_output(ContinueBranch).invoke(
        CONTINUE_BRANCH_PROMPT.format(
            research_plan="\n".join(
                [f"[done] {step}" for i, step in enumerate(state["research_plan"]) if i < state["idx_research_plan"]]
            ),
            query=state["current_node"].query,
            past_queries="\n".join([f"[done] {query}" for query in state["current_node"].get_path_to_root()[1:]]),
            ctx_manager="\n\n---\n\n".join(state["ctx_manager"]),
        )
    )
    logger.info(f"Branch decision '{state['current_node'].query}': {decision['decision']}")
    # BUG FIX: the original expression `idx + 0 if decision else 1` parsed as
    # `(idx + 0) if decision else 1`, so a "stop this branch" decision RESET
    # the plan index to 1 instead of advancing it. Parenthesize the intent.
    return Command(
        goto="scrape",
        update={"idx_research_plan": state["idx_research_plan"] + (0 if decision["decision"] else 1)},
    )


async def gen_report_node(state: ResearchState) -> ResearchState:
    """Produce the final markdown report plus collated media, stream the
    result, and dump debug artifacts to disk."""
    writer = get_stream_writer()
    state["progress"].send(
        writer, 0, {"message": "Generating report..."}, ptype="setter", master_node_for_send=state["master_node"]
    )
    findings = "\n\n------\n\n".join(state["ctx_manager"])
    with open("ctx_manager.log.txt", "w", encoding="utf-8") as f:
        f.write(findings)
    # Generate report outline
    outline = llm.with_structured_output(ReportOutline).invoke(
        REPORT_OUTLINE_PROMPT.format(topic=state["topic"], ctx_manager=findings)
    )
    logger.info(f"Report outline:\n{json.dumps(outline, indent=2)}")
    report = []
    raster_report = f"# {outline['title']}\n\n"
    # Fill in report outline, one heading at a time
    for i, heading in enumerate(outline["headings"]):
        state["progress"].send(
            writer,
            100 / (len(outline["headings"]) + 1),
            {"message": "Generating report..."},
            ptype="update",
            master_node_for_send=state["master_node"],
        )
        content = llm.with_structured_output(ReportFillin).invoke(
            REPORT_FILLIN_PROMPT.format(
                topic=state["topic"],
                ctx_manager=findings,
                report_progress=raster_report,
                # BUG FIX: the original used `if i < _`, marking the headings
                # *after* the current one as "[done]". Completed sections are
                # those before index i (matches the convention used elsewhere).
                report_outline=["[done] " + outline["title"]]
                + [f"[done] {h}" for j, h in enumerate(outline["headings"]) if j < i],
                slot=heading,
            ),
        )["content"]
        # Remove heading if LLM put it there regardless
        idx_heading = content.find(heading)
        if idx_heading != -1:
            content = content[idx_heading + len(heading) :].strip()
        report.append({"heading": heading, "content": content})
        raster_report += f"\n\n## {heading}\n\n{content}"
    # Collate multimedia content
    media_content = {"images": [], "videos": [], "links": []}
    all_sources_data = state["master_node"].get_all_data()
    for data in all_sources_data:
        if data.get("images"):
            media_content["images"].extend(data["images"])
        if data.get("videos"):
            media_content["videos"].extend(data["videos"])
        if data.get("links"):
            media_content["links"].extend([{"url": link["href"], "text": link["text"]} for link in data["links"]])
    # Dedupe — dict.fromkeys keeps first-seen order (set() ordering is arbitrary,
    # which made the report nondeterministic run to run).
    media_content["images"] = list(dict.fromkeys(media_content["images"]))
    media_content["videos"] = list(dict.fromkeys(media_content["videos"]))
    media_content["links"] = [
        json.loads(d) for d in dict.fromkeys(json.dumps(d, sort_keys=True) for d in media_content["links"])
    ]
    result = {
        "topic": state["topic"],
        "timestamp": datetime.now().isoformat(),
        "content": raster_report,
        "media": media_content,
        "research_tree": state["master_node"].build_tree_structure(),
        "metadata": {
            "total_queries": state["master_node"].total_children(),
            "total_sources": len(all_sources_data),
            "max_depth_reached": state["master_node"].max_depth(),
            "total_tokens": state["token_count"],
        },
    }
    with open("output.log.json", "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)
    state["progress"].send(
        writer,
        100,
        result,
        ptype="result",
    )


# --- Main research logic using LangGraph ---
async def start_research_workflow(topic: str, scraper: CrawlForAIScraper, max_depth: int, num_sites_per_query: int):
    """Build and run the research graph, yielding custom stream events."""
    graph = StateGraph(state_schema=ResearchState)
    graph.add_node("plan", research_plan_node)
    graph.add_node("scrape", scrape_node)
    graph.add_node("summarize", summarize_node)
    graph.add_node("should_continue", should_continue_node)
    graph.add_node("gen_report", gen_report_node)
    graph.add_edge("plan", "scrape")
    graph.add_edge("scrape", "summarize")
    graph.add_edge("summarize", "should_continue")
    graph.add_edge("gen_report", END)
    graph.set_entry_point("plan")
    graph = graph.compile()
    logger.debug(graph.get_graph().draw_mermaid())
    master_node = ResearchNode()
    initial_current_node = master_node
    state: ResearchState = {
        "scraper": scraper,
        "progress": ResearchProgress(),
        "topic": topic,
        "max_depth": max_depth,
        "num_sites_per_query": num_sites_per_query,
        "master_node": master_node,
        "current_node": initial_current_node,
        "research_plan": [],
        "idx_research_plan": 0,
        "ctx_researcher": [],
        "ctx_manager": [],
        "raster_report": "",
        "token_count": 0,
    }
    async for update in graph.astream(state, {"recursion_limit": 1000}, stream_mode="custom"):
        yield update


@app.post("/start_research")
async def start_research(request: Request):
    """Kick off (or resume) a research session and stream progress via SSE."""
    data = await request.json()
    topic = data.get("topic", "").strip()
    max_depth = int(data.get("max_depth", 1))
    num_sites_per_query = int(data.get("num_sites_per_query", 5))
    session_id = data.get("session_id") or os.urandom(8).hex()
    if session_id not in sessions:
        scraper = CrawlForAIScraper()
        await scraper.start()
        sessions[session_id] = {"scraper": scraper}
    else:
        scraper = sessions[session_id]["scraper"]

    async def event_generator():
        async for event in start_research_workflow(topic, scraper, max_depth, num_sites_per_query):
            yield event

    return EventSourceResponse(event_generator())


@app.post("/chat")
async def chat(request: Request):
    """Proxy a chat message to the agent and stream its events as SSE."""
    data = await request.json()
    message = data.get("message")
    thread_id = data.get("thread_id")
    create_report = data.get("create_report", False)

    async def event_generator():
        async for event in invoke_agent(message, thread_id, create_report=create_report):
            # Format the event as SSE (Server-Sent Events)
            event_data = json.dumps(event)
            yield f"data: {event_data}\n\n"

    # CONSISTENCY FIX: the original declared media_type="text/plain" while
    # also forcing a "text/event-stream" Content-Type header; declare the
    # SSE media type directly.
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )


@app.post("/abort_research")
async def abort_research(request: Request):
    """Tear down a session's scraper and forget the session."""
    data = await request.json()
    session_id = data.get("session_id")
    if session_id in sessions:
        scraper = sessions[session_id]["scraper"]
        await scraper.close()
        del sessions[session_id]
    return {"status": "aborted"}


if __name__ == "__main__":
    logger.info("Starting KnowledgeNet server...")
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=5000)