"""Gradio front-end for a LangGraph research/essay agent.

Pipeline: plan -> research the plan -> draft -> (critique -> research the
critique -> redraft)* until the revision budget is exhausted.
"""
import atexit
import os
import textwrap
import uuid
from contextlib import ExitStack
from typing import List, Optional, TypedDict

import gradio as gr
from dotenv import load_dotenv
from langchain_cohere import ChatCohere
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import END, StateGraph
from pydantic import BaseModel
from tavily import TavilyClient

# ========== ENVIRONMENT SETUP ==========
load_dotenv()
CO_API_KEY = os.getenv("COHERE_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

# ========== MODEL AND CLIENT SETUP ==========
cohere_model = "command-a-03-2025"
model = ChatCohere(api_key=CO_API_KEY, model=cohere_model)
tavily = TavilyClient(api_key=TAVILY_API_KEY)

# ========== PROMPTS ==========
# Trailing backslashes join the physical lines, so each one must be preceded
# by a space to keep sentences from running together.
PLAN_PROMPT = """You are an expert writer tasked with writing a high level outline of an essay. \
Write such an outline for the user provided topic. Give an outline of the essay along with any relevant notes \
or instructions for the sections."""

WRITER_PROMPT = """You are an essay assistant tasked with writing excellent 5-paragraph essays. \
Generate the best essay possible for the user's request and the initial outline. \
If the user provides critique, respond with a revised version of your previous attempts. \
Utilize all the information below as needed:\n\n------\n\n{content}"""

REFLECTION_PROMPT = """You are a teacher grading an essay submission. \
Generate critique and recommendations for the user's submission. \
Provide detailed recommendations, including requests for length, depth, style, etc."""

RESEARCH_PLAN_PROMPT = """You are a researcher charged with providing information that can \
be used when writing the following essay. Generate a list of search queries that will gather \
any relevant information. Only generate 3 queries max."""

RESEARCH_CRITIQUE_PROMPT = """You are a researcher charged with providing information that can \
be used when making any requested revisions (as outlined below). \
Generate a list of search queries that will gather any relevant information. Only generate 3 queries max."""


# ========== STATE CLASS ==========
class AgentState(TypedDict):
    """Graph state shared by all nodes (each node returns a partial update)."""

    task: str              # the user's topic / request
    plan: str              # outline produced by the planner
    draft: str             # latest essay draft
    critique: str          # latest reviewer critique
    content: List[str]     # accumulated Tavily search snippets
    revision_number: int   # incremented by every generation step
    max_revisions: int     # generation stops once revision_number exceeds this


class Queries(BaseModel):
    """Structured-output schema: the search queries the model wants to run."""

    queries: List[str]


# ========== NODES ==========
def _run_searches(queries: Queries, existing: Optional[List[str]] = None) -> List[str]:
    """Return a NEW list: *existing* snippets followed by Tavily results per query.

    Copying first avoids mutating the list that still belongs to the graph
    state; nodes should communicate only through their returned update.
    """
    content = list(existing or [])
    for query in queries.queries:
        response = tavily.search(query=query, max_results=2)
        content.extend(result["content"] for result in response.get("results", []))
    return content


def plan_node(state: AgentState):
    """Ask the model for an essay outline for ``state['task']``."""
    messages = [SystemMessage(content=PLAN_PROMPT), HumanMessage(content=state["task"])]
    response = model.invoke(messages)
    return {"plan": response.content}


def research_plan_node(state: AgentState):
    """Generate search queries for the task and append their results to ``content``."""
    queries = model.with_structured_output(Queries).invoke([
        SystemMessage(content=RESEARCH_PLAN_PROMPT),
        HumanMessage(content=state["task"]),
    ])
    return {"content": _run_searches(queries, state.get("content"))}


def generation_node(state: AgentState):
    """Write (or rewrite) the essay from the task, plan, and gathered research."""
    content = "\n\n".join(state.get("content", []))
    user_message = HumanMessage(content=f"{state['task']}\n\nHere is my plan:\n\n{state['plan']}")
    messages = [
        SystemMessage(content=WRITER_PROMPT.format(content=content)),
        user_message,
    ]
    response = model.invoke(messages)
    return {
        "draft": response.content,
        "revision_number": state.get("revision_number", 1) + 1,
    }


def reflection_node(state: AgentState):
    """Have the model critique the current draft."""
    messages = [SystemMessage(content=REFLECTION_PROMPT), HumanMessage(content=state["draft"])]
    response = model.invoke(messages)
    return {"critique": response.content}


def research_critique_node(state: AgentState):
    """Search for material addressing the critique; no-op when there is none."""
    if not state.get("critique"):
        return {}
    queries = model.with_structured_output(Queries).invoke([
        SystemMessage(content=RESEARCH_CRITIQUE_PROMPT),
        HumanMessage(content=state["critique"]),
    ])
    return {"content": _run_searches(queries, state.get("content"))}


def should_continue(state: AgentState):
    """Route after generation: stop once the revision budget is exceeded."""
    if state["revision_number"] > state["max_revisions"]:
        return END
    return "reflect"


# ========== GRAPH DEFINITION ==========
builder = StateGraph(AgentState)
builder.add_node("planner", plan_node)
builder.add_node("generate", generation_node)
builder.add_node("reflect", reflection_node)
builder.add_node("research_plan", research_plan_node)
builder.add_node("research_critique", research_critique_node)

builder.set_entry_point("planner")
builder.add_conditional_edges("generate", should_continue, {END: END, "reflect": "reflect"})
builder.add_edge("planner", "research_plan")
builder.add_edge("research_plan", "generate")
builder.add_edge("reflect", "research_critique")
builder.add_edge("research_critique", "generate")

# SqliteSaver.from_conn_string is a context manager; keep it open for the
# process lifetime and close the underlying connection on interpreter exit.
stack = ExitStack()
checkpointer = stack.enter_context(SqliteSaver.from_conn_string(":memory:"))
atexit.register(stack.close)
graph = builder.compile(checkpointer=checkpointer)


# ========== INITIAL STATE FUNCTION ==========
def create_initial_state(overrides: Optional[dict] = None) -> dict:
    """Return a fully-populated initial state, updated with *overrides* if given."""
    state = {
        "task": "",
        "plan": "",
        "draft": "",
        "critique": "",
        "content": [],
        "revision_number": 0,
        "max_revisions": 3,
    }
    if overrides:
        state.update(overrides)
    return state


# ========== GRAPH EXECUTION ==========
def _wrap(text: str, width: int = 100) -> str:
    """Wrap *text* to *width* columns, keeping its existing line breaks.

    ``textwrap.fill`` on the whole string would collapse the outline's and
    essay's newlines into a single run-on paragraph.
    """
    return "\n".join(textwrap.fill(line, width=width) for line in text.splitlines())


def _format_update(node_name: str, update) -> str:
    """Render one streamed ``{node_name: update}`` entry as a log section."""
    parts = [f"\n--- {node_name.upper()} ---\n"]
    if not isinstance(update, dict):
        # e.g. a node that returned an empty update
        parts.append(_wrap(str(update)) + "\n")
        return "".join(parts)
    for key, value in update.items():
        if isinstance(value, str):
            parts.append(f"{key}:\n{_wrap(value)}\n\n")
        elif isinstance(value, list):
            parts.append(f"{key}:\n")
            parts.extend(f"  [{i}] {_wrap(str(item))}\n" for i, item in enumerate(value, 1))
        else:
            parts.append(f"{key}: {value}\n")
    return "".join(parts)


def run_graph_with_topic(topic, max_revisions=2):
    """Run the agent on *topic*, yielding Gradio updates as each node finishes.

    Yields dicts keyed by the output components: the growing process log,
    and the final draft (empty until the run completes).
    """
    if not topic or not topic.strip():
        yield {
            output_log_box: gr.update(value="Please enter a research topic."),
            final_draft_box: gr.update(value=""),
        }
        return

    # A fresh thread id per run keeps concurrent sessions from sharing one
    # checkpoint thread.
    thread = {"configurable": {"thread_id": str(uuid.uuid4())}}
    state = create_initial_state({
        "task": topic,
        "max_revisions": int(max_revisions),  # Gradio sliders may return floats
        "revision_number": 1,
    })

    output_log = ""
    final_draft = ""

    for step in graph.stream(state, thread):
        for node_name, update in step.items():
            output_log += _format_update(node_name, update)
            if isinstance(update, dict) and isinstance(update.get("draft"), str):
                final_draft = update["draft"]

        # Stream intermediate log update
        yield {
            output_log_box: gr.update(value=output_log),
            final_draft_box: gr.update(value=""),  # Clear draft until end
        }

    # Final result
    yield {
        output_log_box: gr.update(value=output_log),
        final_draft_box: gr.update(value=final_draft),
    }


# ========== GRADIO INTERFACE ==========
with gr.Blocks() as demo:
    gr.Markdown("## ✍️ Research Agent \n Developed by: Nader Afshar \n\n Enter a topic and generate a research report.")

    with gr.Row():
        topic_input = gr.Textbox(label="Research Topic", placeholder="e.g., What is the impact of AI on jobs?")
        max_rev_input = gr.Slider(1, 5, value=2, step=1, label="Max Revisions")

    run_button = gr.Button("Generate Report")

    with gr.Row():
        output_log_box = gr.Textbox(label="Agent Process Log", lines=20, interactive=False)
        final_draft_box = gr.Textbox(label="Final Draft", lines=10, interactive=False)

    # run_graph_with_topic is a generator, so Gradio streams each yielded update.
    run_button.click(
        fn=run_graph_with_topic,
        inputs=[topic_input, max_rev_input],
        outputs=[output_log_box, final_draft_box],
    )

demo.launch()