NaderAfshar commited on
Commit
0169c4a
·
1 Parent(s): 42e73aa

renamed the essay_writer.py file to app.py

Browse files
Files changed (4) hide show
  1. .env +1 -1
  2. app.py +253 -0
  3. helper.py +19 -0
  4. requirements.txt +3 -0
.env CHANGED
@@ -1,3 +1,3 @@
1
  COHERE_API_KEY=<REDACTED — a real key was committed here; rotate it immediately and load it from a secrets manager, never from version control>
2
  TAVILY_API_KEY=<REDACTED — a real key was committed here; rotate it immediately>
3
- PORT1=8000
 
1
  COHERE_API_KEY=<REDACTED — rotate this key; .env files must be gitignored>
2
  TAVILY_API_KEY=<REDACTED — rotate this key; .env files must be gitignored>
3
+ #PORT1=8000
app.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from langgraph.graph import StateGraph, END
4
+ from langgraph.checkpoint.sqlite import SqliteSaver
5
+ from typing import List, TypedDict, Annotated
6
+ from langchain_core.messages import (AnyMessage,
7
+ SystemMessage,
8
+ HumanMessage,
9
+ ToolMessage,
10
+ AIMessage )
11
+ from langchain_cohere import ChatCohere
12
+ from tavily import TavilyClient
13
+ from pydantic import BaseModel
14
+
15
+ _ = load_dotenv()
16
+
17
+ CO_API_KEY = os.getenv("COHERE_API_KEY")
18
+ TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
19
+
20
+ cohere_model = "command-a-03-2025"
21
+
22
+ """##### We will build an Agent to write an essay by following the steps depicted
23
+ in the graph below:
24
+
25
+ <img src="Essay_Writer_Graph.JPG">
26
+ """
27
+
28
+
29
+ class AgentState(TypedDict):
30
+ task: str # This is what we are trying to write the essay about
31
+ plan: str # The plan that the planning agent will generate
32
+ draft: str # Draft of the essat
33
+ critique: str # Critique Agent will populate this key
34
+ content: List[str] # List of documents that Tavili has researched.
35
+ revision_number: int
36
+ max_revisions: int
37
+
38
+
39
+ model = ChatCohere(
40
+ api_key=CO_API_KEY,
41
+ model=cohere_model,
42
+ )
43
+
44
+ # This is the prompt for the LLM that will write the plan
45
+ PLAN_PROMPT = """You are an expert writer tasked with writing a high level outline of an essay. \
46
+ Write such an outline for the user provided topic. Give an outline of the essay along with any relevant notes \
47
+ or instructions for the sections."""
48
+
49
+ # This is the prompt for the LLM that will write the essay based on the
50
+ # researched content
51
+ WRITER_PROMPT = """You are an essay assistant tasked with writing excellent 5-paragraph essays.\
52
+ Generate the best essay possible for the user's request and the initial outline. \
53
+ If the user provides critique, respond with a revised version of your previous attempts. \
54
+ Utilize all the information below as needed:
55
+
56
+ ------
57
+
58
+ {content}"""
59
+
60
+ # The Reflection prompt will be used to cretique the essay
61
+ REFLECTION_PROMPT = """You are a teacher grading an essay submission. \
62
+ Generate critique and recommendations for the user's submission. \
63
+ Provide detailed recommendations, including requests for length, depth, style, etc."""
64
+
65
+ # This is the prompt for Researching after the planning step
66
+ # Given a plan we will generate a bunch of queries and pass it to the Tivili for
67
+ # Research
68
+ RESEARCH_PLAN_PROMPT = """You are a researcher charged with providing information that can \
69
+ be used when writing the following essay. Generate a list of search queries that will gather \
70
+ any relevant information. Only generate 3 queries max."""
71
+
72
+ # This is a prompt that will generate new questions for Tivili baseds on the
73
+ # critique of the research. This set of questions is based in the critiques, not
74
+ # to be confused with the planning prompt which serves a similar purpose.
75
+ RESEARCH_CRITIQUE_PROMPT = """You are a researcher charged with providing information that can \
76
+ be used when making any requested revisions (as outlined below). \
77
+ Generate a list of search queries that will gather any relevant information. Only generate 3 queries max."""
78
+
79
+
80
+ class Queries(BaseModel):
81
+ queries: List[str]
82
+
83
+
84
+ tavily = TavilyClient(api_key=TAVILY_API_KEY)
85
+
86
+
87
+ # Define the planning node. Prompt the LLM to develop a "plan"
88
+ def plan_node(state: AgentState):
89
+ messages = [
90
+ SystemMessage(content=PLAN_PROMPT),
91
+ HumanMessage(content=state['task'])
92
+ ]
93
+ response = model.invoke(messages)
94
+ return {"plan": response.content}
95
+
96
+
97
+ def research_plan_node(state: AgentState):
98
+ queries = model.with_structured_output(Queries).invoke([
99
+ SystemMessage(content=RESEARCH_PLAN_PROMPT),
100
+ HumanMessage(content=state['task'])
101
+ ])
102
+ #content = state['content'] or []
103
+ content = state.get('content', [])
104
+ for q in queries.queries:
105
+ response = tavily.search(query=q, max_results=2)
106
+ for r in response['results']:
107
+ content.append(r['content'])
108
+ return {"content": content}
109
+
110
+
111
+ # Generation node will write the first and subsequent drafts
112
+ def generation_node(state: AgentState):
113
+ #content = "\n\n".join(state['content'] or [])
114
+ content = "\n\n".join(state.get('content', []))
115
+ user_message = HumanMessage(
116
+ content=f"{state['task']}\n\nHere is my plan:\n\n{state['plan']}")
117
+ messages = [
118
+ SystemMessage(content=WRITER_PROMPT.format(content=content)),
119
+ user_message
120
+ ]
121
+ response = model.invoke(messages)
122
+ return {
123
+ "draft": response.content,
124
+ "revision_number": state.get("revision_number", 1) + 1
125
+ }
126
+
127
+
128
+ def reflection_node(state: AgentState):
129
+ messages = [
130
+ SystemMessage(content=REFLECTION_PROMPT),
131
+ HumanMessage(content=state['draft'])
132
+ ]
133
+ response = model.invoke(messages)
134
+ return {"critique": response.content}
135
+
136
+ # def research_critique_node(state: AgentState):
137
+ # queries = model.with_structured_output(Queries).invoke([
138
+ # SystemMessage(content=RESEARCH_CRITIQUE_PROMPT),
139
+ # HumanMessage(content=state['critique'])
140
+ # ])
141
+ # #content = state['content'] or []
142
+ # content = state.get('content', [])
143
+ # for q in queries.queries:
144
+ # response = tavily.search(query=q, max_results=2)
145
+ # for r in response['results']:
146
+ # content.append(r['content'])
147
+ # return {"content": content}
148
+
149
+
150
+ # We should only send a HumanMessage(content=state['critique']) if
151
+ # state['critique'] is not empty.
152
+ def research_critique_node(state: AgentState):
153
+ if not state.get('critique'):
154
+ # Skip if there is no critique yet
155
+ return {}
156
+
157
+ queries = model.with_structured_output(Queries).invoke([
158
+ SystemMessage(content=RESEARCH_CRITIQUE_PROMPT),
159
+ HumanMessage(content=state['critique'])
160
+ ])
161
+ content = state['content'] or []
162
+ for q in queries.queries:
163
+ response = tavily.search(query=q, max_results=2)
164
+ for r in response['results']:
165
+ content.append(r['content'])
166
+ return {"content": content}
167
+
168
+
169
+ def should_continue(state):
170
+ if state["revision_number"] > state["max_revisions"]:
171
+ return END
172
+ return "reflect"
173
+
174
+
175
+ builder = StateGraph(AgentState)
176
+
177
+ builder.add_node("planner", plan_node)
178
+ builder.add_node("generate", generation_node)
179
+ builder.add_node("reflect", reflection_node)
180
+ builder.add_node("research_plan", research_plan_node)
181
+ builder.add_node("research_critique", research_critique_node)
182
+
183
+ builder.set_entry_point("planner")
184
+
185
+ builder.add_conditional_edges(
186
+ "generate",
187
+ should_continue,
188
+ {END: END, "reflect": "reflect"}
189
+ )
190
+
191
+ builder.add_edge("planner", "research_plan")
192
+ builder.add_edge("research_plan", "generate")
193
+
194
+ builder.add_edge("reflect", "research_critique")
195
+ builder.add_edge("research_critique", "generate")
196
+
197
+
198
+ from contextlib import ExitStack
199
+ stack = ExitStack()
200
+ checkpointer = stack.enter_context(SqliteSaver.from_conn_string(":memory:"))
201
+ graph = builder.compile(checkpointer=checkpointer)
202
+
203
+ #from IPython.display import Image
204
+ #Image(graph.get_graph().draw_png())
205
+
206
+ from PIL import Image as PILImage
207
+ from io import BytesIO
208
+ image_bytes = graph.get_graph().draw_png()
209
+ img = PILImage.open(BytesIO(image_bytes))
210
+ img.show()
211
+
212
+
213
+ def create_initial_state(overrides: dict = None) -> dict:
214
+ # Default initial blank state
215
+ state = {
216
+ "task": "",
217
+ "plan": "",
218
+ "draft": "",
219
+ "critique": "",
220
+ "content": [],
221
+ "revision_number": 0,
222
+ "max_revisions": 3
223
+ }
224
+ if overrides:
225
+ state.update(overrides)
226
+ return state
227
+
228
+
229
+ thread = {"configurable": {"thread_id": "1"}}
230
+
231
+ initial_state = create_initial_state({
232
+ 'task': "what is the difference between langchain and langsmith",
233
+ "max_revisions": 2,
234
+ "revision_number": 1,
235
+ })
236
+
237
+ import textwrap
238
+
239
+ for s in graph.stream(initial_state, thread):
240
+ for k, v in s.items():
241
+ print(f"\n--- {k.upper()} ---")
242
+ if isinstance(v, dict):
243
+ for subkey, value in v.items():
244
+ if isinstance(value, str):
245
+ print(f"{subkey}:\n{textwrap.fill(value, width=100)}\n")
246
+ elif isinstance(value, list):
247
+ print(f"{subkey}:")
248
+ for i, item in enumerate(value, 1):
249
+ print(f" [{i}] {textwrap.fill(str(item), width=100)}\n")
250
+ else:
251
+ print(f"{subkey}: {value}")
252
+ else:
253
+ print(textwrap.fill(str(v), width=100))
helper.py CHANGED
@@ -2,6 +2,7 @@ import warnings
2
  from dotenv import load_dotenv
3
 
4
  import os
 
5
  import gradio as gr
6
 
7
  from langgraph.graph import StateGraph, END
@@ -181,6 +182,7 @@ class writer_gui( ):
181
  self.iterations = []
182
  self.threads = []
183
  self.thread_id = -1
 
184
  self.thread = {"configurable": {"thread_id": str(self.thread_id)}}
185
  #self.sdisps = {} #global
186
  self.demo = self.create_interface()
@@ -195,6 +197,17 @@ class writer_gui( ):
195
  'content': ["no content",], 'queries': "no queries", 'count':0}
196
  self.thread_id += 1 # new agent, new thread
197
  self.threads.append(self.thread_id)
 
 
 
 
 
 
 
 
 
 
 
198
  else:
199
  config = None
200
 
@@ -454,3 +467,9 @@ class writer_gui( ):
454
  self.demo.launch(share=True, server_port=int(port), server_name="0.0.0.0")
455
  else:
456
  self.demo.launch(share=self.share)
 
 
 
 
 
 
 
2
  from dotenv import load_dotenv
3
 
4
  import os
5
+ import time
6
  import gradio as gr
7
 
8
  from langgraph.graph import StateGraph, END
 
182
  self.iterations = []
183
  self.threads = []
184
  self.thread_id = -1
185
+ self.thread_ts_map = {} # maps thread_id -> a stable thread_ts value for checkpoint lookups
186
  self.thread = {"configurable": {"thread_id": str(self.thread_id)}}
187
  #self.sdisps = {} #global
188
  self.demo = self.create_interface()
 
197
  'content': ["no content",], 'queries': "no queries", 'count':0}
198
  self.thread_id += 1 # new agent, new thread
199
  self.threads.append(self.thread_id)
200
+
201
+ # --------++++++>>
202
+ if self.thread_id not in self.thread_ts_map:
203
+ self.thread_ts_map[self.thread_id] = str(self.thread_id) # or use a stable UUID
204
+
205
+ self.thread = {"configurable": {
206
+ "thread_id": str(self.thread_id),
207
+ "thread_ts": self.thread_ts_map[self.thread_id],
208
+ }}
209
+ # --------++++++>>
210
+
211
  else:
212
  config = None
213
 
 
467
  self.demo.launch(share=True, server_port=int(port), server_name="0.0.0.0")
468
  else:
469
  self.demo.launch(share=self.share)
470
+
471
+
472
+ if __name__ == "__main__":
473
+ agent = ewriter()
474
+ gui = writer_gui(agent.graph)
475
+ gui.launch()
requirements.txt CHANGED
@@ -9,4 +9,7 @@ langgraph-checkpoint
9
  langgraph-checkpoint-sqlite
10
  aiosqlite
11
  dotenv
 
 
 
12
  gradio
 
9
  langgraph-checkpoint-sqlite
10
  aiosqlite
11
  python-dotenv  # NOTE(review): the package providing `from dotenv import load_dotenv` is python-dotenv, not dotenv
12
+ IPython
13
+ pillow
14
+ pygraphviz
15
  gradio