File size: 7,936 Bytes
0169c4a
 
 
 
6b3edaa
 
0169c4a
 
 
6b3edaa
 
 
0169c4a
6b3edaa
 
0169c4a
 
 
6b3edaa
0169c4a
6b3edaa
 
0169c4a
6b3edaa
0169c4a
 
 
 
 
 
 
6b3edaa
0169c4a
 
 
 
 
 
 
 
 
 
 
 
 
 
6b3edaa
 
 
 
 
 
 
 
 
0169c4a
 
6b3edaa
 
0169c4a
 
6b3edaa
0169c4a
6b3edaa
0169c4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b3edaa
0169c4a
 
 
6b3edaa
0169c4a
 
 
 
 
 
 
 
6b3edaa
0169c4a
 
 
6b3edaa
0169c4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b3edaa
0169c4a
 
 
 
 
6b3edaa
0169c4a
 
 
 
 
 
 
 
6b3edaa
0169c4a
 
 
 
 
 
 
 
 
 
6b3edaa
0169c4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b3edaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0169c4a
 
6b3edaa
 
73c1ece
6b3edaa
 
25108f8
6b3edaa
 
25108f8
6b3edaa
 
 
25108f8
6b3edaa
 
 
 
 
0169c4a
6b3edaa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import os
from dotenv import load_dotenv
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.sqlite import SqliteSaver
from typing import List, TypedDict
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_cohere import ChatCohere
from tavily import TavilyClient
from pydantic import BaseModel
import textwrap
import gradio as gr
from contextlib import ExitStack

# ========== ENVIRONMENT SETUP ==========
# Load API keys from a local .env file into the process environment.
# Note: os.getenv returns None when a key is missing; failures would then
# surface later, inside the client constructors or the first API call.
load_dotenv()
CO_API_KEY = os.getenv("COHERE_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

# ========== MODEL AND CLIENT SETUP ==========
# Single chat model shared by every graph node; Tavily performs web searches.
cohere_model = "command-a-03-2025"
model = ChatCohere(api_key=CO_API_KEY, model=cohere_model)
tavily = TavilyClient(api_key=TAVILY_API_KEY)

# ========== PROMPTS ==========
# System prompts, one per graph node. Trailing backslashes inside the
# triple-quoted strings are line continuations, so the rendered prompts
# contain no hard line breaks at those points.

# planner node: produce a high-level outline for the topic.
PLAN_PROMPT = """You are an expert writer tasked with writing a high level outline of an essay. \
Write such an outline for the user provided topic. Give an outline of the essay along with any relevant notes \
or instructions for the sections."""

# generate node: {content} is filled with researched snippets via str.format.
WRITER_PROMPT = """You are an essay assistant tasked with writing excellent 5-paragraph essays.\
Generate the best essay possible for the user's request and the initial outline. \
If the user provides critique, respond with a revised version of your previous attempts. \
Utilize all the information below as needed:\n\n------\n\n{content}"""

# reflect node: grade the current draft and request changes.
REFLECTION_PROMPT = """You are a teacher grading an essay submission. \
Generate critique and recommendations for the user's submission. \
Provide detailed recommendations, including requests for length, depth, style, etc."""

# research_plan node: turn the topic into at most 3 search queries.
RESEARCH_PLAN_PROMPT = """You are a researcher charged with providing information that can \
be used when writing the following essay. Generate a list of search queries that will gather \
any relevant information. Only generate 3 queries max."""

# research_critique node: turn the critique into at most 3 search queries.
RESEARCH_CRITIQUE_PROMPT = """You are a researcher charged with providing information that can \
be used when making any requested revisions (as outlined below). \
Generate a list of search queries that will gather any relevant information. Only generate 3 queries max."""


# ========== STATE CLASS ==========
class AgentState(TypedDict):
    """State dictionary shared by all LangGraph nodes in this graph."""
    task: str             # user-provided essay topic
    plan: str             # outline produced by plan_node
    draft: str            # latest essay draft from generation_node
    critique: str         # latest feedback from reflection_node
    content: List[str]    # accumulated research snippets from Tavily
    revision_number: int  # incremented each time a draft is generated
    max_revisions: int    # loop stops once revision_number exceeds this


class Queries(BaseModel):
    """Structured-output schema: a list of search-query strings."""
    queries: List[str]


# ========== NODES ==========
def plan_node(state: AgentState):
    """Ask the model for a high-level essay outline for the current task."""
    prompt = [
        SystemMessage(content=PLAN_PROMPT),
        HumanMessage(content=state['task']),
    ]
    outline = model.invoke(prompt)
    return {"plan": outline.content}


def research_plan_node(state: AgentState):
    """Generate up to 3 search queries for the task and gather Tavily results.

    Returns:
        {"content": new_list} where new_list is a copy of the existing
        research snippets extended with fresh search results.
    """
    queries = model.with_structured_output(Queries).invoke([
        SystemMessage(content=RESEARCH_PLAN_PROMPT),
        HumanMessage(content=state['task'])
    ])
    # Copy before appending: the original code mutated the list object held
    # in the previous (checkpointed) state in place, which can corrupt
    # earlier checkpoints that share the same list.
    content = list(state.get('content', []))
    for q in queries.queries:
        response = tavily.search(query=q, max_results=2)
        for r in response['results']:
            content.append(r['content'])
    return {"content": content}


def generation_node(state: AgentState):
    """Draft (or redraft) the essay from the plan plus researched content."""
    research = "\n\n".join(state.get('content', []))
    task_msg = HumanMessage(
        content=f"{state['task']}\n\nHere is my plan:\n\n{state['plan']}"
    )
    draft = model.invoke([
        SystemMessage(content=WRITER_PROMPT.format(content=research)),
        task_msg,
    ])
    # Count this attempt; default of 1 covers a state that never set it.
    next_revision = state.get("revision_number", 1) + 1
    return {"draft": draft.content, "revision_number": next_revision}


def reflection_node(state: AgentState):
    """Produce teacher-style critique and recommendations for the draft."""
    critique = model.invoke([
        SystemMessage(content=REFLECTION_PROMPT),
        HumanMessage(content=state['draft']),
    ])
    return {"critique": critique.content}


def research_critique_node(state: AgentState):
    """Generate search queries from the critique and gather Tavily results.

    Returns:
        {} when there is no critique to act on, otherwise {"content": new_list}
        where new_list is a copy of the existing snippets plus fresh results.
    """
    if not state.get('critique'):
        return {}
    queries = model.with_structured_output(Queries).invoke([
        SystemMessage(content=RESEARCH_CRITIQUE_PROMPT),
        HumanMessage(content=state['critique'])
    ])
    # .get avoids the KeyError the original `state['content']` raised when the
    # key was absent (now consistent with research_plan_node), and the copy
    # avoids mutating the list held by the previous checkpointed state.
    content = list(state.get('content') or [])
    for q in queries.queries:
        response = tavily.search(query=q, max_results=2)
        for r in response['results']:
            content.append(r['content'])
    return {"content": content}


def should_continue(state: AgentState):
    """Route to END once revisions are exhausted, otherwise to 'reflect'."""
    out_of_revisions = state["revision_number"] > state["max_revisions"]
    return END if out_of_revisions else "reflect"


# ========== GRAPH DEFINITION ==========
# Flow: planner -> research_plan -> generate -> (reflect -> research_critique
# -> generate)* until should_continue returns END.
builder = StateGraph(AgentState)
builder.add_node("planner", plan_node)
builder.add_node("generate", generation_node)
builder.add_node("reflect", reflection_node)
builder.add_node("research_plan", research_plan_node)
builder.add_node("research_critique", research_critique_node)

builder.set_entry_point("planner")
# After each draft, either stop (END) or loop through critique + research.
builder.add_conditional_edges("generate", should_continue, {END: END, "reflect": "reflect"})
builder.add_edge("planner", "research_plan")
builder.add_edge("research_plan", "generate")
builder.add_edge("reflect", "research_critique")
builder.add_edge("research_critique", "generate")

# SqliteSaver.from_conn_string returns a context manager; ExitStack keeps it
# open for the lifetime of the process so the compiled graph can checkpoint.
stack = ExitStack()
checkpointer = stack.enter_context(SqliteSaver.from_conn_string(":memory:"))
graph = builder.compile(checkpointer=checkpointer)


# ========== INITIAL STATE FUNCTION ==========
def create_initial_state(overrides: dict = None) -> dict:
    """Build a default agent state dict, optionally overriding fields.

    Args:
        overrides: mapping whose entries replace the defaults; None or an
            empty mapping leaves the defaults untouched.

    Returns:
        A fresh dict (new object per call, including a new 'content' list).
    """
    defaults = {
        "task": "",
        "plan": "",
        "draft": "",
        "critique": "",
        "content": [],
        "revision_number": 0,
        "max_revisions": 3,
    }
    if overrides:
        defaults.update(overrides)
    return defaults


# ========== GRAPH EXECUTION ==========
def run_graph_with_topic(topic, max_revisions=2):
    """Stream the agent graph for `topic`, yielding Gradio UI updates.

    Yields dicts keyed by the module-level Gradio components
    (output_log_box, final_draft_box): the running log after every graph
    event, then a final yield that also fills in the last draft.

    NOTE(review): thread_id is hard-coded to "1", so every run shares one
    checkpoint thread — presumably acceptable for a single-user demo; confirm
    if concurrent runs are expected.
    """
    thread = {"configurable": {"thread_id": "1"}}
    state = create_initial_state({
        "task": topic,
        "max_revisions": max_revisions,
        "revision_number": 1
    })

    output_log = ""
    final_draft = ""

    # Each streamed item maps node name -> the partial state that node returned.
    for s in graph.stream(state, thread):
        for k, v in s.items():
            output_log += f"\n--- {k.upper()} ---\n"
            if isinstance(v, dict):
                for subkey, value in v.items():
                    if isinstance(value, str):
                        output_log += f"{subkey}:\n{textwrap.fill(value, width=100)}\n\n"
                        # Remember the most recent draft for the final yield.
                        if subkey == "draft":
                            final_draft = value
                    elif isinstance(value, list):
                        output_log += f"{subkey}:\n"
                        for i, item in enumerate(value, 1):
                            output_log += f"  [{i}] {textwrap.fill(str(item), width=100)}\n"
                    else:
                        output_log += f"{subkey}: {value}\n"
            else:
                output_log += textwrap.fill(str(v), width=100) + "\n"

        # Stream intermediate log update
        yield {
            output_log_box: gr.update(value=output_log),
            final_draft_box: gr.update(value="")  # Clear draft until end
        }

    # Final result
    yield {
        output_log_box: gr.update(value=output_log),
        final_draft_box: gr.update(value=final_draft)
    }


# ========== GRADIO INTERFACE ==========
# Layout: topic + revision slider on top, run button, then two read-only
# text boxes (process log, final draft) that run_graph_with_topic updates.
with gr.Blocks() as demo:
    gr.Markdown("## ✍️ Research Agent \n Developed by: Nader Afshar \n\n Enter a topic and generate a research report.")

    with gr.Row():
        topic_input = gr.Textbox(label="Research Topic", placeholder="e.g., What is the impact of AI on jobs?")
        max_rev_input = gr.Slider(1, 5, value=2, step=1, label="Max Revisions")

    run_button = gr.Button("Generate Report")

    with gr.Row():
        output_log_box = gr.Textbox(label="Agent Process Log", lines=20, interactive=False)
        final_draft_box = gr.Textbox(label="Final Draft", lines=10, interactive=False)

    # This is the corrected streaming connection
    # (generator handler: each yield from run_graph_with_topic updates the UI)
    run_button.click(fn=run_graph_with_topic,
                     inputs=[topic_input, max_rev_input],
                     outputs=[output_log_box, final_draft_box])

demo.launch()