NaderAfshar commited on
Commit
6b3edaa
·
1 Parent(s): 0169c4a

updated with gradio interface

Browse files
Files changed (1) hide show
  1. app.py +93 -125
app.py CHANGED
@@ -2,94 +2,66 @@ import os
2
  from dotenv import load_dotenv
3
  from langgraph.graph import StateGraph, END
4
  from langgraph.checkpoint.sqlite import SqliteSaver
5
- from typing import List, TypedDict, Annotated
6
- from langchain_core.messages import (AnyMessage,
7
- SystemMessage,
8
- HumanMessage,
9
- ToolMessage,
10
- AIMessage )
11
  from langchain_cohere import ChatCohere
12
  from tavily import TavilyClient
13
  from pydantic import BaseModel
 
 
 
14
 
15
- _ = load_dotenv()
16
-
17
  CO_API_KEY = os.getenv("COHERE_API_KEY")
18
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
19
 
 
20
  cohere_model = "command-a-03-2025"
 
 
21
 
22
- """##### We will build an Agent to write an essay by following the steps depicted
23
- in the graph below:
24
-
25
- <img src="Essay_Writer_Graph.JPG">
26
- """
27
-
28
-
29
- class AgentState(TypedDict):
30
- task: str # This is what we are trying to write the essay about
31
- plan: str # The plan that the planning agent will generate
32
- draft: str # Draft of the essay
33
- critique: str # Critique Agent will populate this key
34
- content: List[str] # List of documents that Tavily has researched.
35
- revision_number: int
36
- max_revisions: int
37
-
38
-
39
- model = ChatCohere(
40
- api_key=CO_API_KEY,
41
- model=cohere_model,
42
- )
43
-
44
- # This is the prompt for the LLM that will write the plan
45
  PLAN_PROMPT = """You are an expert writer tasked with writing a high level outline of an essay. \
46
  Write such an outline for the user provided topic. Give an outline of the essay along with any relevant notes \
47
  or instructions for the sections."""
48
 
49
- # This is the prompt for the LLM that will write the essay based on the
50
- # researched content
51
  WRITER_PROMPT = """You are an essay assistant tasked with writing excellent 5-paragraph essays.\
52
  Generate the best essay possible for the user's request and the initial outline. \
53
  If the user provides critique, respond with a revised version of your previous attempts. \
54
- Utilize all the information below as needed:
55
-
56
- ------
57
-
58
- {content}"""
59
 
60
- # The Reflection prompt will be used to critique the essay
61
  REFLECTION_PROMPT = """You are a teacher grading an essay submission. \
62
  Generate critique and recommendations for the user's submission. \
63
  Provide detailed recommendations, including requests for length, depth, style, etc."""
64
 
65
- # This is the prompt for Researching after the planning step
66
- # Given a plan we will generate a bunch of queries and pass it to Tavily for
67
- # Research
68
  RESEARCH_PLAN_PROMPT = """You are a researcher charged with providing information that can \
69
  be used when writing the following essay. Generate a list of search queries that will gather \
70
  any relevant information. Only generate 3 queries max."""
71
 
72
- # This is a prompt that will generate new questions for Tavily based on the
73
- # critique of the research. This set of questions is based on the critiques, not
74
- # to be confused with the planning prompt which serves a similar purpose.
75
  RESEARCH_CRITIQUE_PROMPT = """You are a researcher charged with providing information that can \
76
  be used when making any requested revisions (as outlined below). \
77
  Generate a list of search queries that will gather any relevant information. Only generate 3 queries max."""
78
 
79
 
80
- class Queries(BaseModel):
81
- queries: List[str]
 
 
 
 
 
 
 
82
 
83
 
84
- tavily = TavilyClient(api_key=TAVILY_API_KEY)
 
85
 
86
 
87
- # Define the planning node. Prompt the LLM to develop a "plan"
88
  def plan_node(state: AgentState):
89
- messages = [
90
- SystemMessage(content=PLAN_PROMPT),
91
- HumanMessage(content=state['task'])
92
- ]
93
  response = model.invoke(messages)
94
  return {"plan": response.content}
95
 
@@ -99,7 +71,6 @@ def research_plan_node(state: AgentState):
99
  SystemMessage(content=RESEARCH_PLAN_PROMPT),
100
  HumanMessage(content=state['task'])
101
  ])
102
- #content = state['content'] or []
103
  content = state.get('content', [])
104
  for q in queries.queries:
105
  response = tavily.search(query=q, max_results=2)
@@ -108,16 +79,13 @@ def research_plan_node(state: AgentState):
108
  return {"content": content}
109
 
110
 
111
- # Generation node will write the first and subsequent drafts
112
  def generation_node(state: AgentState):
113
- #content = "\n\n".join(state['content'] or [])
114
  content = "\n\n".join(state.get('content', []))
115
- user_message = HumanMessage(
116
- content=f"{state['task']}\n\nHere is my plan:\n\n{state['plan']}")
117
  messages = [
118
  SystemMessage(content=WRITER_PROMPT.format(content=content)),
119
  user_message
120
- ]
121
  response = model.invoke(messages)
122
  return {
123
  "draft": response.content,
@@ -126,34 +94,14 @@ def generation_node(state: AgentState):
126
 
127
 
128
  def reflection_node(state: AgentState):
129
- messages = [
130
- SystemMessage(content=REFLECTION_PROMPT),
131
- HumanMessage(content=state['draft'])
132
- ]
133
  response = model.invoke(messages)
134
  return {"critique": response.content}
135
 
136
- # def research_critique_node(state: AgentState):
137
- # queries = model.with_structured_output(Queries).invoke([
138
- # SystemMessage(content=RESEARCH_CRITIQUE_PROMPT),
139
- # HumanMessage(content=state['critique'])
140
- # ])
141
- # #content = state['content'] or []
142
- # content = state.get('content', [])
143
- # for q in queries.queries:
144
- # response = tavily.search(query=q, max_results=2)
145
- # for r in response['results']:
146
- # content.append(r['content'])
147
- # return {"content": content}
148
-
149
-
150
- # We should only send a HumanMessage(content=state['critique']) if
151
- # state['critique'] is not empty.
152
  def research_critique_node(state: AgentState):
153
  if not state.get('critique'):
154
- # Skip if there is no critique yet
155
  return {}
156
-
157
  queries = model.with_structured_output(Queries).invoke([
158
  SystemMessage(content=RESEARCH_CRITIQUE_PROMPT),
159
  HumanMessage(content=state['critique'])
@@ -166,14 +114,14 @@ def research_critique_node(state: AgentState):
166
  return {"content": content}
167
 
168
 
169
- def should_continue(state):
170
  if state["revision_number"] > state["max_revisions"]:
171
  return END
172
  return "reflect"
173
 
174
 
 
175
  builder = StateGraph(AgentState)
176
-
177
  builder.add_node("planner", plan_node)
178
  builder.add_node("generate", generation_node)
179
  builder.add_node("reflect", reflection_node)
@@ -181,37 +129,19 @@ builder.add_node("research_plan", research_plan_node)
181
  builder.add_node("research_critique", research_critique_node)
182
 
183
  builder.set_entry_point("planner")
184
-
185
- builder.add_conditional_edges(
186
- "generate",
187
- should_continue,
188
- {END: END, "reflect": "reflect"}
189
- )
190
-
191
  builder.add_edge("planner", "research_plan")
192
  builder.add_edge("research_plan", "generate")
193
-
194
  builder.add_edge("reflect", "research_critique")
195
  builder.add_edge("research_critique", "generate")
196
 
197
-
198
- from contextlib import ExitStack
199
  stack = ExitStack()
200
  checkpointer = stack.enter_context(SqliteSaver.from_conn_string(":memory:"))
201
  graph = builder.compile(checkpointer=checkpointer)
202
 
203
- #from IPython.display import Image
204
- #Image(graph.get_graph().draw_png())
205
-
206
- from PIL import Image as PILImage
207
- from io import BytesIO
208
- image_bytes = graph.get_graph().draw_png()
209
- img = PILImage.open(BytesIO(image_bytes))
210
- img.show()
211
-
212
 
 
213
  def create_initial_state(overrides: dict = None) -> dict:
214
- # Default initial blank state
215
  state = {
216
  "task": "",
217
  "plan": "",
@@ -226,28 +156,66 @@ def create_initial_state(overrides: dict = None) -> dict:
226
  return state
227
 
228
 
229
- thread = {"configurable": {"thread_id": "1"}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
- initial_state = create_initial_state({
232
- 'task': "what is the difference between langchain and langsmith",
233
- "max_revisions": 2,
234
- "revision_number": 1,
235
- })
236
 
237
- import textwrap
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
- for s in graph.stream(initial_state, thread):
240
- for k, v in s.items():
241
- print(f"\n--- {k.upper()} ---")
242
- if isinstance(v, dict):
243
- for subkey, value in v.items():
244
- if isinstance(value, str):
245
- print(f"{subkey}:\n{textwrap.fill(value, width=100)}\n")
246
- elif isinstance(value, list):
247
- print(f"{subkey}:")
248
- for i, item in enumerate(value, 1):
249
- print(f" [{i}] {textwrap.fill(str(item), width=100)}\n")
250
- else:
251
- print(f"{subkey}: {value}")
252
- else:
253
- print(textwrap.fill(str(v), width=100))
 
2
  from dotenv import load_dotenv
3
  from langgraph.graph import StateGraph, END
4
  from langgraph.checkpoint.sqlite import SqliteSaver
5
+ from typing import List, TypedDict
6
+ from langchain_core.messages import SystemMessage, HumanMessage
 
 
 
 
7
  from langchain_cohere import ChatCohere
8
  from tavily import TavilyClient
9
  from pydantic import BaseModel
10
+ import textwrap
11
+ import gradio as gr
12
+ from contextlib import ExitStack
13
 
14
+ # ========== ENVIRONMENT SETUP ==========
15
+ load_dotenv()
16
  CO_API_KEY = os.getenv("COHERE_API_KEY")
17
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
18
 
19
+ # ========== MODEL AND CLIENT SETUP ==========
20
  cohere_model = "command-a-03-2025"
21
+ model = ChatCohere(api_key=CO_API_KEY, model=cohere_model)
22
+ tavily = TavilyClient(api_key=TAVILY_API_KEY)
23
 
24
+ # ========== PROMPTS ==========
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  PLAN_PROMPT = """You are an expert writer tasked with writing a high level outline of an essay. \
26
  Write such an outline for the user provided topic. Give an outline of the essay along with any relevant notes \
27
  or instructions for the sections."""
28
 
 
 
29
  WRITER_PROMPT = """You are an essay assistant tasked with writing excellent 5-paragraph essays.\
30
  Generate the best essay possible for the user's request and the initial outline. \
31
  If the user provides critique, respond with a revised version of your previous attempts. \
32
+ Utilize all the information below as needed:\n\n------\n\n{content}"""
 
 
 
 
33
 
 
34
  REFLECTION_PROMPT = """You are a teacher grading an essay submission. \
35
  Generate critique and recommendations for the user's submission. \
36
  Provide detailed recommendations, including requests for length, depth, style, etc."""
37
 
 
 
 
38
  RESEARCH_PLAN_PROMPT = """You are a researcher charged with providing information that can \
39
  be used when writing the following essay. Generate a list of search queries that will gather \
40
  any relevant information. Only generate 3 queries max."""
41
 
 
 
 
42
  RESEARCH_CRITIQUE_PROMPT = """You are a researcher charged with providing information that can \
43
  be used when making any requested revisions (as outlined below). \
44
  Generate a list of search queries that will gather any relevant information. Only generate 3 queries max."""
45
 
46
 
47
+ # ========== STATE CLASS ==========
48
class AgentState(TypedDict):
    """Shared state threaded through every node of the essay-writer graph."""
    task: str             # The essay topic requested by the user
    plan: str             # High-level outline produced by the planner node
    draft: str            # Current essay draft from the generation node
    critique: str         # Feedback produced by the reflection node
    content: List[str]    # Research snippets gathered via Tavily searches
    revision_number: int  # Counter of drafts produced so far (starts at 1)
    max_revisions: int    # Budget: routing stops once revision_number exceeds this
56
 
57
 
58
class Queries(BaseModel):
    """Structured-output schema for the research nodes: the search queries
    the LLM proposes (prompts cap this at 3), each later sent to Tavily."""
    queries: List[str]  # Plain search-query strings
60
 
61
 
62
+ # ========== NODES ==========
63
def plan_node(state: AgentState):
    """Ask the LLM for a high-level essay outline for the user's task."""
    # The system prompt sets the planner persona; the raw task is the human turn.
    prompt = [
        SystemMessage(content=PLAN_PROMPT),
        HumanMessage(content=state['task']),
    ]
    outline = model.invoke(prompt)
    # Return only the key this node owns; LangGraph merges it into the state.
    return {"plan": outline.content}
67
 
 
71
  SystemMessage(content=RESEARCH_PLAN_PROMPT),
72
  HumanMessage(content=state['task'])
73
  ])
 
74
  content = state.get('content', [])
75
  for q in queries.queries:
76
  response = tavily.search(query=q, max_results=2)
 
79
  return {"content": content}
80
 
81
 
 
82
  def generation_node(state: AgentState):
 
83
  content = "\n\n".join(state.get('content', []))
84
+ user_message = HumanMessage(content=f"{state['task']}\n\nHere is my plan:\n\n{state['plan']}")
 
85
  messages = [
86
  SystemMessage(content=WRITER_PROMPT.format(content=content)),
87
  user_message
88
+ ]
89
  response = model.invoke(messages)
90
  return {
91
  "draft": response.content,
 
94
 
95
 
96
def reflection_node(state: AgentState):
    """Grade the current draft: produce critique and revision recommendations."""
    conversation = [
        SystemMessage(content=REFLECTION_PROMPT),
        HumanMessage(content=state['draft']),
    ]
    feedback = model.invoke(conversation)
    return {"critique": feedback.content}
100
 
101
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  def research_critique_node(state: AgentState):
103
  if not state.get('critique'):
 
104
  return {}
 
105
  queries = model.with_structured_output(Queries).invoke([
106
  SystemMessage(content=RESEARCH_CRITIQUE_PROMPT),
107
  HumanMessage(content=state['critique'])
 
114
  return {"content": content}
115
 
116
 
117
def should_continue(state: AgentState):
    """Routing predicate after 'generate': END once the revision budget is
    exhausted, otherwise send the draft on to the 'reflect' node."""
    over_budget = state["revision_number"] > state["max_revisions"]
    return END if over_budget else "reflect"
121
 
122
 
123
+ # ========== GRAPH DEFINITION ==========
124
  builder = StateGraph(AgentState)
 
125
  builder.add_node("planner", plan_node)
126
  builder.add_node("generate", generation_node)
127
  builder.add_node("reflect", reflection_node)
 
129
  builder.add_node("research_critique", research_critique_node)
130
 
131
  builder.set_entry_point("planner")
132
+ builder.add_conditional_edges("generate", should_continue, {END: END, "reflect": "reflect"})
 
 
 
 
 
 
133
  builder.add_edge("planner", "research_plan")
134
  builder.add_edge("research_plan", "generate")
 
135
  builder.add_edge("reflect", "research_critique")
136
  builder.add_edge("research_critique", "generate")
137
 
 
 
138
  stack = ExitStack()
139
  checkpointer = stack.enter_context(SqliteSaver.from_conn_string(":memory:"))
140
  graph = builder.compile(checkpointer=checkpointer)
141
 
 
 
 
 
 
 
 
 
 
142
 
143
+ # ========== INITIAL STATE FUNCTION ==========
144
  def create_initial_state(overrides: dict = None) -> dict:
 
145
  state = {
146
  "task": "",
147
  "plan": "",
 
156
  return state
157
 
158
 
159
+ # ========== GRAPH EXECUTION ==========
160
def run_graph_with_topic(topic, max_revisions=2):
    """Stream the essay-writer graph for *topic*, yielding Gradio updates.

    Generator used as the click handler: intermediate yields push the growing
    process log into ``output_log_box`` (keeping the draft box empty), and the
    final yield also fills ``final_draft_box`` with the last draft produced.
    """
    import uuid

    # Fresh thread id per run: the previous hard-coded "1" made every UI click
    # resume the same SqliteSaver checkpoint thread, leaking state across runs.
    thread = {"configurable": {"thread_id": uuid.uuid4().hex}}
    state = create_initial_state({
        "task": topic,
        "max_revisions": int(max_revisions),  # gr.Slider may deliver a float
        "revision_number": 1
    })

    log_parts = []   # chunks joined on demand — avoids quadratic str +=
    final_draft = ""

    for s in graph.stream(state, thread):
        for node_name, update in s.items():
            log_parts.append(f"\n--- {node_name.upper()} ---\n")
            if isinstance(update, dict):
                for subkey, value in update.items():
                    if isinstance(value, str):
                        log_parts.append(f"{subkey}:\n{textwrap.fill(value, width=100)}\n\n")
                        if subkey == "draft":
                            final_draft = value  # remember the latest draft
                    elif isinstance(value, list):
                        log_parts.append(f"{subkey}:\n")
                        for i, item in enumerate(value, 1):
                            log_parts.append(f" [{i}] {textwrap.fill(str(item), width=100)}\n")
                    else:
                        log_parts.append(f"{subkey}: {value}\n")
            else:
                log_parts.append(textwrap.fill(str(update), width=100) + "\n")

        # Stream intermediate log update; clear the draft box until the end.
        yield {
            output_log_box: gr.update(value="".join(log_parts)),
            final_draft_box: gr.update(value="")
        }

    # Final result: complete log plus the finished draft.
    yield {
        output_log_box: gr.update(value="".join(log_parts)),
        final_draft_box: gr.update(value=final_draft)
    }
200
 
 
 
 
 
 
201
 
202
+ # ========== GRADIO INTERFACE ==========
203
# Build the Gradio Blocks UI.
# NOTE: output_log_box / final_draft_box are referenced by variable inside
# run_graph_with_topic (as keys of the yielded update dicts), so these
# component names must stay in sync with that function.
with gr.Blocks() as demo:
    gr.Markdown("## ✍️ LangGraph Essay Writer\nEnter a topic and generate a researched, revised essay.")

    with gr.Row():
        topic_input = gr.Textbox(label="Essay Topic", placeholder="e.g., What is the impact of AI on jobs?")
        max_rev_input = gr.Slider(1, 5, value=2, step=1, label="Max Revisions")

    run_button = gr.Button("Generate Essay")

    with gr.Row():
        output_log_box = gr.Textbox(label="Agent Process Log", lines=20, interactive=False)
        final_draft_box = gr.Textbox(label="Final Essay Draft", lines=10, interactive=False)

    # run_graph_with_topic is a generator, so click() streams each yield
    # into the two output components while the graph is still running.
    run_button.click(fn=run_graph_with_topic,
                     inputs=[topic_input, max_rev_input],
                     outputs=[output_log_box, final_draft_box])

demo.launch()