# Hugging Face Space app: research essay agent (LangGraph + Cohere + Tavily + Gradio)
import os
import textwrap
from contextlib import ExitStack
from typing import List, Optional, TypedDict

import gradio as gr
from dotenv import load_dotenv
from langchain_cohere import ChatCohere
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import StateGraph, END
from pydantic import BaseModel
from tavily import TavilyClient
# ========== ENVIRONMENT SETUP ==========
# Load API keys from a local .env file into the process environment.
load_dotenv()
CO_API_KEY = os.getenv("COHERE_API_KEY")      # None if unset — ChatCohere calls will fail at request time
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")  # None if unset — Tavily searches will fail at request time
# ========== MODEL AND CLIENT SETUP ==========
# One shared chat model and one shared search client, used by every graph node below.
cohere_model = "command-a-03-2025"
model = ChatCohere(api_key=CO_API_KEY, model=cohere_model)
tavily = TavilyClient(api_key=TAVILY_API_KEY)
# ========== PROMPTS ==========
# System prompts for each graph node. WRITER_PROMPT is a str.format template:
# {content} is filled with the accumulated research snippets at generation time.
PLAN_PROMPT = """You are an expert writer tasked with writing a high level outline of an essay. \
Write such an outline for the user provided topic. Give an outline of the essay along with any relevant notes \
or instructions for the sections."""
WRITER_PROMPT = """You are an essay assistant tasked with writing excellent 5-paragraph essays.\
Generate the best essay possible for the user's request and the initial outline. \
If the user provides critique, respond with a revised version of your previous attempts. \
Utilize all the information below as needed:\n\n------\n\n{content}"""
REFLECTION_PROMPT = """You are a teacher grading an essay submission. \
Generate critique and recommendations for the user's submission. \
Provide detailed recommendations, including requests for length, depth, style, etc."""
RESEARCH_PLAN_PROMPT = """You are a researcher charged with providing information that can \
be used when writing the following essay. Generate a list of search queries that will gather \
any relevant information. Only generate 3 queries max."""
RESEARCH_CRITIQUE_PROMPT = """You are a researcher charged with providing information that can \
be used when making any requested revisions (as outlined below). \
Generate a list of search queries that will gather any relevant information. Only generate 3 queries max."""
# ========== STATE CLASS ==========
class AgentState(TypedDict):
    """Shared state dict passed between the LangGraph nodes below."""
    task: str             # user-supplied essay topic / request
    plan: str             # high-level outline produced by plan_node
    draft: str            # current essay draft from generation_node
    critique: str         # grader feedback from reflection_node
    content: List[str]    # accumulated research snippets from Tavily searches
    revision_number: int  # incremented by generation_node on every draft
    max_revisions: int    # loop budget checked by should_continue
class Queries(BaseModel):
    """Structured-output schema: the model returns a list of search queries."""
    queries: List[str]  # search strings, capped at 3 by the research prompts
# ========== NODES ==========
def plan_node(state: AgentState):
    """Produce a high-level essay outline for the task in *state*."""
    prompt = [
        SystemMessage(content=PLAN_PROMPT),
        HumanMessage(content=state['task']),
    ]
    outline = model.invoke(prompt)
    return {"plan": outline.content}
def research_plan_node(state: AgentState):
    """Generate up to 3 search queries for the task and collect Tavily results.

    Returns:
        dict with a new ``content`` list (existing snippets + fresh results).

    The list held in *state* is copied rather than mutated: appending to
    ``state.get('content', [])`` would modify graph state in place when the
    key is present, aliasing the returned list with the stored one.
    """
    queries = model.with_structured_output(Queries).invoke([
        SystemMessage(content=RESEARCH_PLAN_PROMPT),
        HumanMessage(content=state['task'])
    ])
    content = list(state.get('content', []))  # copy — never mutate state's list
    for q in queries.queries:
        response = tavily.search(query=q, max_results=2)
        for r in response['results']:
            content.append(r['content'])
    return {"content": content}
def generation_node(state: AgentState):
    """Write (or rewrite) the essay draft from the plan and research content."""
    research = "\n\n".join(state.get('content', []))
    request = HumanMessage(
        content=f"{state['task']}\n\nHere is my plan:\n\n{state['plan']}"
    )
    draft = model.invoke([
        SystemMessage(content=WRITER_PROMPT.format(content=research)),
        request,
    ])
    next_revision = state.get("revision_number", 1) + 1
    return {"draft": draft.content, "revision_number": next_revision}
def reflection_node(state: AgentState):
    """Grade the current draft and return critique/recommendations."""
    feedback = model.invoke([
        SystemMessage(content=REFLECTION_PROMPT),
        HumanMessage(content=state['draft']),
    ])
    return {"critique": feedback.content}
def research_critique_node(state: AgentState):
    """Turn the latest critique into follow-up searches; extend research content.

    Returns:
        Empty dict (no state change) when there is no critique yet; otherwise
        a dict with a new ``content`` list extended with fresh Tavily results.

    Fixes two issues versus the sibling node: ``state['content']`` raised
    KeyError when the key was absent (research_plan_node uses ``.get``), and
    the state's list was mutated in place — now it is copied first.
    """
    if not state.get('critique'):
        return {}
    queries = model.with_structured_output(Queries).invoke([
        SystemMessage(content=RESEARCH_CRITIQUE_PROMPT),
        HumanMessage(content=state['critique'])
    ])
    content = list(state.get('content', []))  # copy — never mutate state's list
    for q in queries.queries:
        response = tavily.search(query=q, max_results=2)
        for r in response['results']:
            content.append(r['content'])
    return {"content": content}
| def should_continue(state: AgentState): | |
| if state["revision_number"] > state["max_revisions"]: | |
| return END | |
| return "reflect" | |
# ========== GRAPH DEFINITION ==========
# Pipeline: planner -> research_plan -> generate -> (reflect -> research_critique -> generate)*
# The generate node loops through the critique cycle until should_continue returns END.
builder = StateGraph(AgentState)
builder.add_node("planner", plan_node)
builder.add_node("generate", generation_node)
builder.add_node("reflect", reflection_node)
builder.add_node("research_plan", research_plan_node)
builder.add_node("research_critique", research_critique_node)
builder.set_entry_point("planner")
# After each draft: either stop (END) or go around the critique loop again.
builder.add_conditional_edges("generate", should_continue, {END: END, "reflect": "reflect"})
builder.add_edge("planner", "research_plan")
builder.add_edge("research_plan", "generate")
builder.add_edge("reflect", "research_critique")
builder.add_edge("research_critique", "generate")
# SqliteSaver.from_conn_string returns a context manager; ExitStack keeps the
# in-memory checkpointer open for the lifetime of the process.
# NOTE(review): the stack is never closed explicitly — presumably fine for an
# in-memory DB, but confirm no shutdown cleanup is required.
stack = ExitStack()
checkpointer = stack.enter_context(SqliteSaver.from_conn_string(":memory:"))
graph = builder.compile(checkpointer=checkpointer)
# ========== INITIAL STATE FUNCTION ==========
def create_initial_state(overrides: Optional[dict] = None) -> dict:
    """Build a fresh AgentState-shaped dict, optionally overriding fields.

    Args:
        overrides: Optional mapping merged over the defaults (e.g. ``task``,
            ``max_revisions``). ``None`` means "use the defaults unchanged".
            (Annotation fixed: the default is None, so ``dict`` alone was wrong.)

    Returns:
        A brand-new dict on every call — the mutable ``content`` list is
        created inside the function, so it is never shared between states.
    """
    state = {
        "task": "",
        "plan": "",
        "draft": "",
        "critique": "",
        "content": [],
        "revision_number": 0,
        "max_revisions": 3,
    }
    if overrides:
        state.update(overrides)
    return state
# ========== GRAPH EXECUTION ==========
def run_graph_with_topic(topic, max_revisions=2):
    """Stream the agent graph for *topic*, yielding Gradio UI updates.

    Yields a dict mapping the two output components to updates: the running
    process log after every graph step, then the log plus the final draft
    once streaming finishes.
    """
    config = {"configurable": {"thread_id": "1"}}
    initial = create_initial_state({
        "task": topic,
        "max_revisions": max_revisions,
        "revision_number": 1,
    })
    log_text = ""
    latest_draft = ""
    for step in graph.stream(initial, config):
        for node_name, payload in step.items():
            log_text += f"\n--- {node_name.upper()} ---\n"
            if not isinstance(payload, dict):
                log_text += textwrap.fill(str(payload), width=100) + "\n"
            else:
                for field, val in payload.items():
                    if isinstance(val, str):
                        log_text += f"{field}:\n{textwrap.fill(val, width=100)}\n\n"
                        if field == "draft":
                            latest_draft = val
                    elif isinstance(val, list):
                        log_text += f"{field}:\n"
                        for idx, entry in enumerate(val, 1):
                            log_text += f" [{idx}] {textwrap.fill(str(entry), width=100)}\n"
                    else:
                        log_text += f"{field}: {val}\n"
        # Push the partial log; keep the draft box empty until the run finishes.
        yield {
            output_log_box: gr.update(value=log_text),
            final_draft_box: gr.update(value=""),
        }
    # Final update: full log plus the last draft produced.
    yield {
        output_log_box: gr.update(value=log_text),
        final_draft_box: gr.update(value=latest_draft),
    }
# ========== GRADIO INTERFACE ==========
# Layout: topic input + revision slider on one row, run button, then the
# process log and final draft side by side. Components must be constructed
# before run_graph_with_topic's yields reference them at click time.
with gr.Blocks() as demo:
    gr.Markdown("## ✍️ Research Agent \n Developed by: Nader Afshar \n\n Enter a topic and generate a research report.")
    with gr.Row():
        topic_input = gr.Textbox(label="Research Topic", placeholder="e.g., What is the impact of AI on jobs?")
        max_rev_input = gr.Slider(1, 5, value=2, step=1, label="Max Revisions")
    run_button = gr.Button("Generate Report")
    with gr.Row():
        output_log_box = gr.Textbox(label="Agent Process Log", lines=20, interactive=False)
        final_draft_box = gr.Textbox(label="Final Draft", lines=10, interactive=False)
    # Generator handler: Gradio streams each yielded update to the two boxes.
    run_button.click(fn=run_graph_with_topic,
                     inputs=[topic_input, max_rev_input],
                     outputs=[output_log_box, final_draft_box])
demo.launch()