# NOTE(review): removed non-Python residue above (file-viewer header, commit
# hashes, and a line-number gutter) that made this module unparseable.
import os
import textwrap
from contextlib import ExitStack
from typing import List, Optional, TypedDict

import gradio as gr
from dotenv import load_dotenv
from langchain_cohere import ChatCohere
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import StateGraph, END
from pydantic import BaseModel
from tavily import TavilyClient
# ========== ENVIRONMENT SETUP ==========
load_dotenv()  # read a local .env file into os.environ, if present
CO_API_KEY = os.getenv("COHERE_API_KEY")  # may be None; ChatCohere will then fail at call time
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")  # may be None; Tavily searches will then fail
# ========== MODEL AND CLIENT SETUP ==========
cohere_model = "command-a-03-2025"  # Cohere chat model id used for every node
model = ChatCohere(api_key=CO_API_KEY, model=cohere_model)  # shared LLM client for all graph nodes
tavily = TavilyClient(api_key=TAVILY_API_KEY)  # web-search client used by the research nodes
# ========== PROMPTS ==========
# System prompt for plan_node: produce a high-level essay outline.
PLAN_PROMPT = """You are an expert writer tasked with writing a high level outline of an essay. \
Write such an outline for the user provided topic. Give an outline of the essay along with any relevant notes \
or instructions for the sections."""
# System prompt for generation_node; {content} is filled with researched snippets.
WRITER_PROMPT = """You are an essay assistant tasked with writing excellent 5-paragraph essays.\
Generate the best essay possible for the user's request and the initial outline. \
If the user provides critique, respond with a revised version of your previous attempts. \
Utilize all the information below as needed:\n\n------\n\n{content}"""
# System prompt for reflection_node: critique the current draft.
REFLECTION_PROMPT = """You are a teacher grading an essay submission. \
Generate critique and recommendations for the user's submission. \
Provide detailed recommendations, including requests for length, depth, style, etc."""
# System prompt for research_plan_node: turn the task into <= 3 search queries.
RESEARCH_PLAN_PROMPT = """You are a researcher charged with providing information that can \
be used when writing the following essay. Generate a list of search queries that will gather \
any relevant information. Only generate 3 queries max."""
# System prompt for research_critique_node: turn the critique into <= 3 search queries.
RESEARCH_CRITIQUE_PROMPT = """You are a researcher charged with providing information that can \
be used when making any requested revisions (as outlined below). \
Generate a list of search queries that will gather any relevant information. Only generate 3 queries max."""
# ========== STATE CLASS ==========
class AgentState(TypedDict):
    """Shared state dict passed between the LangGraph nodes."""
    task: str             # the user's essay topic / request
    plan: str             # outline produced by plan_node
    draft: str            # latest essay draft from generation_node
    critique: str         # feedback produced by reflection_node
    content: List[str]    # accumulated web-search snippets
    revision_number: int  # incremented by generation_node on each draft
    max_revisions: int    # budget checked by should_continue
class Queries(BaseModel):
    """Structured-output schema: the search queries the model generates."""
    queries: List[str]  # prompts ask for at most 3 entries
# ========== NODES ==========
def plan_node(state: AgentState):
    """Ask the model for a high-level outline of the requested essay."""
    prompt = [
        SystemMessage(content=PLAN_PROMPT),
        HumanMessage(content=state['task']),
    ]
    outline = model.invoke(prompt)
    return {"plan": outline.content}
def research_plan_node(state: AgentState):
    """Generate up to 3 search queries for the task and collect web results.

    Returns:
        dict with a new ``content`` list (existing snippets plus the text of
        each Tavily result). A copy is returned so the list held in the
        incoming state is never mutated in place; the graph merges the
        returned value back into state.
    """
    queries = model.with_structured_output(Queries).invoke([
        SystemMessage(content=RESEARCH_PLAN_PROMPT),
        HumanMessage(content=state['task'])
    ])
    # Copy: appending to state's own list would be a hidden side effect on
    # the caller's state object.
    content = list(state.get('content', []))
    for q in queries.queries:
        response = tavily.search(query=q, max_results=2)
        for r in response['results']:
            content.append(r['content'])
    return {"content": content}
def generation_node(state: AgentState):
    """Write (or rewrite) the essay draft from the plan plus researched content."""
    research_blob = "\n\n".join(state.get('content', []))
    system = SystemMessage(content=WRITER_PROMPT.format(content=research_blob))
    request = HumanMessage(
        content=f"{state['task']}\n\nHere is my plan:\n\n{state['plan']}"
    )
    draft = model.invoke([system, request])
    next_revision = state.get("revision_number", 1) + 1
    return {"draft": draft.content, "revision_number": next_revision}
def reflection_node(state: AgentState):
    """Grade the current draft and return the model's critique."""
    critique = model.invoke([
        SystemMessage(content=REFLECTION_PROMPT),
        HumanMessage(content=state['draft']),
    ])
    return {"critique": critique.content}
def research_critique_node(state: AgentState):
    """Search the web for material addressing the critique's revision requests.

    Returns:
        ``{}`` (no state change) when there is no critique yet; otherwise a
        dict with a new ``content`` list — existing snippets plus the text of
        each Tavily result — built as a copy so the list held in state is
        never mutated in place.
    """
    if not state.get('critique'):
        return {}
    queries = model.with_structured_output(Queries).invoke([
        SystemMessage(content=RESEARCH_CRITIQUE_PROMPT),
        HumanMessage(content=state['critique'])
    ])
    # .get avoids a KeyError when 'content' is absent (same defensive access
    # as research_plan_node); list() copies to avoid in-place mutation.
    content = list(state.get('content') or [])
    for q in queries.queries:
        response = tavily.search(query=q, max_results=2)
        for r in response['results']:
            content.append(r['content'])
    return {"content": content}
def should_continue(state: AgentState):
    """Route to END once the revision budget is spent, otherwise to 'reflect'."""
    budget_spent = state["revision_number"] > state["max_revisions"]
    return END if budget_spent else "reflect"
# ========== GRAPH DEFINITION ==========
builder = StateGraph(AgentState)
builder.add_node("planner", plan_node)
builder.add_node("generate", generation_node)
builder.add_node("reflect", reflection_node)
builder.add_node("research_plan", research_plan_node)
builder.add_node("research_critique", research_critique_node)
builder.set_entry_point("planner")
# After each draft: stop when the revision budget is exhausted, else reflect.
builder.add_conditional_edges("generate", should_continue, {END: END, "reflect": "reflect"})
# Main flow: plan -> research -> draft; critiques loop back through research.
builder.add_edge("planner", "research_plan")
builder.add_edge("research_plan", "generate")
builder.add_edge("reflect", "research_critique")
builder.add_edge("research_critique", "generate")
# from_conn_string returns a context manager; ExitStack keeps the in-memory
# SQLite checkpointer open for the lifetime of the process.
stack = ExitStack()
checkpointer = stack.enter_context(SqliteSaver.from_conn_string(":memory:"))
graph = builder.compile(checkpointer=checkpointer)
# ========== INITIAL STATE FUNCTION ==========
def create_initial_state(overrides: Optional[dict] = None) -> dict:
    """Build the default AgentState dict, optionally overriding fields.

    Args:
        overrides: mapping of AgentState keys to replace in the defaults
            (e.g. ``{"task": ..., "max_revisions": ...}``). ``None`` (the
            default) applies no overrides. Annotated ``Optional[dict]``
            because ``None`` is a valid (and default) argument.

    Returns:
        A fresh dict matching the AgentState schema; each call returns a
        new ``content`` list, so callers never share mutable state.
    """
    state = {
        "task": "",
        "plan": "",
        "draft": "",
        "critique": "",
        "content": [],
        "revision_number": 0,
        "max_revisions": 3
    }
    if overrides:
        state.update(overrides)
    return state
# ========== GRAPH EXECUTION ==========
def run_graph_with_topic(topic, max_revisions=2):
    """Stream the essay graph for *topic*, yielding Gradio component updates.

    Yields dicts keyed by the UI components: the process log grows after
    every node, the draft box stays empty until the final yield, which
    carries the last draft produced.
    """
    config = {"configurable": {"thread_id": "1"}}
    initial_state = create_initial_state({
        "task": topic,
        "max_revisions": max_revisions,
        "revision_number": 1
    })
    log_parts = []
    latest_draft = ""
    for step in graph.stream(initial_state, config):
        for node_name, payload in step.items():
            log_parts.append(f"\n--- {node_name.upper()} ---\n")
            if not isinstance(payload, dict):
                log_parts.append(textwrap.fill(str(payload), width=100) + "\n")
                continue
            for field, value in payload.items():
                if isinstance(value, str):
                    log_parts.append(f"{field}:\n{textwrap.fill(value, width=100)}\n\n")
                    if field == "draft":
                        latest_draft = value
                elif isinstance(value, list):
                    log_parts.append(f"{field}:\n")
                    for idx, entry in enumerate(value, 1):
                        log_parts.append(f" [{idx}] {textwrap.fill(str(entry), width=100)}\n")
                else:
                    log_parts.append(f"{field}: {value}\n")
        # Intermediate update: growing log, draft box kept empty until the end.
        yield {
            output_log_box: gr.update(value="".join(log_parts)),
            final_draft_box: gr.update(value="")
        }
    # Final update: complete log plus the last draft.
    yield {
        output_log_box: gr.update(value="".join(log_parts)),
        final_draft_box: gr.update(value=latest_draft)
    }
# ========== GRADIO INTERFACE ==========
with gr.Blocks() as demo:
    gr.Markdown("## ✍️ Research Agent \n Developed by: Nader Afshar \n\n Enter a topic and generate a research report.")
    with gr.Row():
        # Inputs: essay topic plus the revision budget fed to the graph state.
        topic_input = gr.Textbox(label="Research Topic", placeholder="e.g., What is the impact of AI on jobs?")
        max_rev_input = gr.Slider(1, 5, value=2, step=1, label="Max Revisions")
    run_button = gr.Button("Generate Report")
    with gr.Row():
        # Outputs targeted by the dicts run_graph_with_topic yields.
        output_log_box = gr.Textbox(label="Agent Process Log", lines=20, interactive=False)
        final_draft_box = gr.Textbox(label="Final Draft", lines=10, interactive=False)
    # run_graph_with_topic is a generator, so clicking streams incremental
    # updates into both output boxes.
    run_button.click(fn=run_graph_with_topic,
                     inputs=[topic_input, max_rev_input],
                     outputs=[output_log_box, final_draft_box])
# Starts the web server (blocking) when the module is executed.
demo.launch()
# NOTE(review): removed a stray "|" left by the file-viewer extraction.