NaderAfshar
Corrected a type
73c1ece
import os
import textwrap
import uuid
from contextlib import ExitStack
from typing import List, TypedDict

import gradio as gr
from dotenv import load_dotenv
from langchain_cohere import ChatCohere
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import StateGraph, END
from pydantic import BaseModel
from tavily import TavilyClient
# ========== ENVIRONMENT SETUP ==========
# Load variables from a local .env file (if one exists) into the process
# environment, then read the two API keys. os.getenv returns None for any
# variable that is not set.
load_dotenv()
CO_API_KEY = os.getenv("COHERE_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
# ========== MODEL AND CLIENT SETUP ==========
# Module-level clients shared by every graph node below.
# NOTE(review): both clients are constructed even when a key above is None —
# confirm how ChatCohere / TavilyClient react to a missing key (environment
# fallback vs. an error at construction or on first call).
cohere_model = "command-a-03-2025"  # model name passed to ChatCohere
model = ChatCohere(api_key=CO_API_KEY, model=cohere_model)
tavily = TavilyClient(api_key=TAVILY_API_KEY)
# ========== PROMPTS ==========
# System prompt for plan_node: asks for a high-level outline of the essay.
PLAN_PROMPT = """You are an expert writer tasked with writing a high level outline of an essay. \
Write such an outline for the user provided topic. Give an outline of the essay along with any relevant notes \
or instructions for the sections."""
# System prompt for generation_node. It is rendered with str.format(content=...),
# so `{content}` must stay the only brace pair (literal braces would need doubling).
# Each trailing backslash joins the next source line WITHOUT a newline, so a space
# is required before it to keep sentences apart ("essays. Generate", not
# "essays.Generate").
WRITER_PROMPT = """You are an essay assistant tasked with writing excellent 5-paragraph essays. \
Generate the best essay possible for the user's request and the initial outline. \
If the user provides critique, respond with a revised version of your previous attempts. \
Utilize all the information below as needed:\n\n------\n\n{content}"""
# System prompt for reflection_node: critiques the current draft.
REFLECTION_PROMPT = """You are a teacher grading an essay submission. \
Generate critique and recommendations for the user's submission. \
Provide detailed recommendations, including requests for length, depth, style, etc."""
# System prompts for the two research nodes (initial research for the task, and
# follow-up research for the critique). The "3 queries max" limit is only
# requested in text; the code does not truncate the returned query list.
RESEARCH_PLAN_PROMPT = """You are a researcher charged with providing information that can \
be used when writing the following essay. Generate a list of search queries that will gather \
any relevant information. Only generate 3 queries max."""
RESEARCH_CRITIQUE_PROMPT = """You are a researcher charged with providing information that can \
be used when making any requested revisions (as outlined below). \
Generate a list of search queries that will gather any relevant information. Only generate 3 queries max."""
# ========== STATE CLASS ==========
class AgentState(TypedDict):
    """Graph state shared by all essay-writing nodes in this module.

    No field carries a reducer annotation, so (per LangGraph's default
    behavior) a value returned by a node replaces the stored value; the
    research nodes therefore return the full accumulated ``content`` list.
    """
    task: str  # the user's topic / request (HumanMessage for planner and research_plan)
    plan: str  # outline written by plan_node
    draft: str  # latest essay text written by generation_node
    critique: str  # feedback written by reflection_node; empty until the first reflection
    content: List[str]  # research snippets collected from Tavily search results
    revision_number: int  # incremented by generation_node after every draft
    max_revisions: int  # should_continue ends the loop once revision_number exceeds this
class Queries(BaseModel):
    # Structured-output schema used by research_plan_node and
    # research_critique_node: the model is asked to return a list of search
    # query strings, each of which is sent to Tavily.
    # NOTE: documented with comments rather than a class docstring on purpose —
    # LangChain may forward a model's docstring to the LLM as the schema
    # description, which would change the prompt.
    queries: List[str]
# ========== NODES ==========
def plan_node(state: AgentState):
    """Ask the model for a high-level essay outline of ``state['task']``.

    Returns:
        dict: ``{"plan": <text of the model's reply>}``.
    """
    conversation = [
        SystemMessage(content=PLAN_PROMPT),
        HumanMessage(content=state['task']),
    ]
    outline = model.invoke(conversation).content
    return {"plan": outline}
def research_plan_node(state: AgentState):
    """Research the task: generate search queries and collect Tavily snippets.

    Asks the model (via structured output into ``Queries``) for search queries
    about ``state['task']``, runs each through Tavily keeping the top 2
    results, and appends each result's ``content`` text.

    Returns:
        dict: ``{"content": [...]}`` -- the existing snippets followed by the
        new ones, as a NEW list.
    """
    result = model.with_structured_output(Queries).invoke([
        SystemMessage(content=RESEARCH_PLAN_PROMPT),
        HumanMessage(content=state['task']),
    ])
    # Structured output may come back as None when the model does not produce
    # the structured call; treat that as "no queries" instead of crashing.
    search_queries = result.queries if result is not None else []
    # Copy instead of appending in place: the incoming list is the object held
    # in graph state, and mutating it would change state outside the normal
    # node-return update. `or []` also tolerates a missing or None `content`.
    content = list(state.get('content') or [])
    for q in search_queries:
        response = tavily.search(query=q, max_results=2)
        content.extend(r['content'] for r in response['results'])
    return {"content": content}
def generation_node(state: AgentState):
    """Write the essay draft, or revise it once a critique exists.

    The system prompt embeds all collected research snippets. The user
    message carries the task and plan and, after the first reflection round,
    also the previous draft and the critique. WRITER_PROMPT asks the model to
    revise when critique is provided, but without these the model never sees
    the critique and the reflect loop cannot improve the draft.

    Returns:
        dict: the new ``draft`` and ``revision_number`` incremented by one
        (a missing ``revision_number`` is treated as 1).
    """
    content = "\n\n".join(state.get('content') or [])
    user_text = f"{state['task']}\n\nHere is my plan:\n\n{state['plan']}"
    critique = state.get('critique')
    if critique:
        # Only on revision rounds; the first draft sees exactly the same
        # message as before (critique starts out empty).
        previous_draft = state.get('draft') or ""
        user_text += (
            f"\n\nHere is my previous draft:\n\n{previous_draft}"
            f"\n\nHere is the critique to address:\n\n{critique}"
        )
    messages = [
        SystemMessage(content=WRITER_PROMPT.format(content=content)),
        HumanMessage(content=user_text),
    ]
    response = model.invoke(messages)
    return {
        "draft": response.content,
        "revision_number": state.get("revision_number", 1) + 1,
    }
def reflection_node(state: AgentState):
    """Have the model grade the current draft.

    Returns:
        dict: ``{"critique": <text of the model's reply>}``.
    """
    prompt = [
        SystemMessage(content=REFLECTION_PROMPT),
        HumanMessage(content=state['draft']),
    ]
    return {"critique": model.invoke(prompt).content}
def research_critique_node(state: AgentState):
    """Run follow-up research driven by the latest critique.

    Returns ``{}`` (no state change) when there is no critique. Otherwise it
    asks the model for search queries addressing ``state['critique']``, runs
    each through Tavily keeping the top 2 results, and returns the existing
    snippets followed by the new ones.

    Returns:
        dict: ``{}`` or ``{"content": [...]}`` with a NEW list.
    """
    if not state.get('critique'):
        return {}
    result = model.with_structured_output(Queries).invoke([
        SystemMessage(content=RESEARCH_CRITIQUE_PROMPT),
        HumanMessage(content=state['critique']),
    ])
    # Structured output may come back as None when the model does not produce
    # the structured call; treat that as "no queries" instead of crashing.
    search_queries = result.queries if result is not None else []
    # Build a new list rather than appending to the one held in graph state;
    # `.get(...) or []` also tolerates a missing or None `content`.
    content = list(state.get('content') or [])
    for q in search_queries:
        response = tavily.search(query=q, max_results=2)
        content.extend(r['content'] for r in response['results'])
    return {"content": content}
def should_continue(state: AgentState):
    """Route after ``generate``: stop once the revision budget is used up.

    Returns:
        ``END`` when ``revision_number`` is greater than ``max_revisions``,
        otherwise ``"reflect"`` (the next node to run).
    """
    budget_exhausted = state["revision_number"] > state["max_revisions"]
    return END if budget_exhausted else "reflect"
# ========== GRAPH DEFINITION ==========
builder = StateGraph(AgentState)

# --- Nodes ---
builder.add_node("planner", plan_node)
builder.add_node("generate", generation_node)
builder.add_node("reflect", reflection_node)
builder.add_node("research_plan", research_plan_node)
builder.add_node("research_critique", research_critique_node)

# --- Edges, listed in execution order ---
# planner -> research_plan -> generate -> [reflect -> research_critique -> generate]* -> END
builder.set_entry_point("planner")
builder.add_edge("planner", "research_plan")
builder.add_edge("research_plan", "generate")
builder.add_conditional_edges("generate", should_continue, {END: END, "reflect": "reflect"})
builder.add_edge("reflect", "research_critique")
builder.add_edge("research_critique", "generate")

# --- Checkpointing ---
# from_conn_string is used as a context manager; entering it on an ExitStack
# that is never closed keeps the in-memory SQLite saver open for the lifetime
# of the process.
stack = ExitStack()
checkpointer = stack.enter_context(SqliteSaver.from_conn_string(":memory:"))
graph = builder.compile(checkpointer=checkpointer)
# ========== INITIAL STATE FUNCTION ==========
def create_initial_state(overrides: dict = None) -> dict:
    """Build a fresh graph state with every ``AgentState`` key populated.

    Args:
        overrides: Optional mapping whose entries replace (or add to) the
            defaults. ``None`` or an empty mapping leaves the defaults as-is.

    Returns:
        dict: A new dict on every call, including a new empty ``content``
        list, so callers never share mutable state.
    """
    fresh = dict(
        task="",
        plan="",
        draft="",
        critique="",
        content=[],
        revision_number=0,
        max_revisions=3,
    )
    fresh.update(overrides or {})
    return fresh
# ========== GRAPH EXECUTION ==========
def _render_update(node_name, update):
    """Format one streamed node update for the process log.

    Returns:
        tuple: ``(log_text, draft)`` where ``draft`` is the update's string
        ``draft`` value if it had one, else ``None``.
    """
    text = f"\n--- {node_name.upper()} ---\n"
    draft = None
    if not isinstance(update, dict):
        return text + textwrap.fill(str(update), width=100) + "\n", draft
    for key, value in update.items():
        if isinstance(value, str):
            text += f"{key}:\n{textwrap.fill(value, width=100)}\n\n"
            if key == "draft":
                draft = value
        elif isinstance(value, list):
            text += f"{key}:\n"
            text += "".join(
                f" [{i}] {textwrap.fill(str(item), width=100)}\n"
                for i, item in enumerate(value, 1)
            )
        else:
            text += f"{key}: {value}\n"
    return text, draft


def run_graph_with_topic(topic, max_revisions=2):
    """Run the essay graph for ``topic`` and stream progress to the Gradio UI.

    Yields dicts mapping the ``output_log_box`` / ``final_draft_box``
    components to ``gr.update`` objects. After each graph step it yields the
    accumulated log with the draft box cleared. At the end it yields the log
    together with the last draft produced.

    Args:
        topic: Research topic / essay request entered by the user. A blank
            topic yields a single prompt message and does not run the graph.
        max_revisions: Stored in the graph state; ``should_continue`` stops
            once ``revision_number`` exceeds it.
    """
    if not topic or not topic.strip():
        yield {
            output_log_box: gr.update(value="Please enter a research topic."),
            final_draft_box: gr.update(value=""),
        }
        return
    # Fresh checkpoint thread per run: a shared constant id would make
    # concurrent sessions read and write the same checkpoint thread.
    thread = {"configurable": {"thread_id": str(uuid.uuid4())}}
    state = create_initial_state({
        "task": topic,
        "max_revisions": max_revisions,
        "revision_number": 1,
    })
    output_log = ""
    final_draft = ""
    for step in graph.stream(state, thread):
        for node_name, update in step.items():
            text, draft = _render_update(node_name, update)
            output_log += text
            if draft is not None:
                final_draft = draft
        # Stream intermediate log update; keep the draft box empty until the end.
        yield {
            output_log_box: gr.update(value=output_log),
            final_draft_box: gr.update(value=""),
        }
    # Final result
    yield {
        output_log_box: gr.update(value=output_log),
        final_draft_box: gr.update(value=final_draft),
    }
# ========== GRADIO INTERFACE ==========
with gr.Blocks() as demo:
    gr.Markdown("## ✍️ Research Agent \n Developed by: Nader Afshar \n\n Enter a topic and generate a research report.")
    with gr.Row():
        topic_input = gr.Textbox(label="Research Topic", placeholder="e.g., What is the impact of AI on jobs?")
        max_rev_input = gr.Slider(1, 5, value=2, step=1, label="Max Revisions")
    run_button = gr.Button("Generate Report")
    with gr.Row():
        output_log_box = gr.Textbox(label="Agent Process Log", lines=20, interactive=False)
        final_draft_box = gr.Textbox(label="Final Draft", lines=10, interactive=False)
    # run_graph_with_topic is a generator that yields {component: gr.update(...)}
    # dicts, so each yield is streamed into the UI. Every component used as a
    # key in those dicts must be listed in `outputs`.
    run_button.click(
        fn=run_graph_with_topic,
        inputs=[topic_input, max_rev_input],
        outputs=[output_log_box, final_draft_box],
    )

# Start the server only when run as a script, so importing this module (e.g.
# from tests or another launcher) does not block on launch() as a side effect.
if __name__ == "__main__":
    demo.launch()