NaderAfshar committed on
Commit
a7b24ed
·
1 Parent(s): 6f23954

Initial commit of files

Browse files
Files changed (1) hide show
  1. helper.py +456 -0
helper.py ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from dotenv import load_dotenv
3
+
4
+ import os
5
+ import gradio as gr
6
+
7
+ from langgraph.graph import StateGraph, END
8
+ from langgraph.checkpoint.sqlite import SqliteSaver
9
+ from typing import List, TypedDict, Annotated
10
+ import operator
11
+ from langchain_core.messages import (AnyMessage,
12
+ SystemMessage,
13
+ HumanMessage,
14
+ ToolMessage,
15
+ AIMessage )
16
+ from langchain_cohere import ChatCohere
17
+ from tavily import TavilyClient
18
+ from pydantic import BaseModel
19
+ import sqlite3
20
+
21
+ warnings.filterwarnings("ignore", message=".*TqdmWarning.*")
22
+ _ = load_dotenv()
23
+
24
+ CO_API_KEY = os.getenv("COHERE_API_KEY")
25
+ TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
26
+
27
+ cohere_model = "command-a-03-2025"
28
+
29
+
30
+ class AgentState(TypedDict):
31
+ task: str
32
+ lnode: str
33
+ plan: str
34
+ draft: str
35
+ critique: str
36
+ content: List[str]
37
+ queries: List[str]
38
+ revision_number: int
39
+ max_revisions: int
40
+ count: Annotated[int, operator.add]
41
+
42
+
43
+ class Queries(BaseModel):
44
+ queries: List[str]
45
+
46
+
47
+ class ewriter():
48
+ def __init__(self):
49
+ self.model = ChatCohere(api_key=CO_API_KEY, model=cohere_model, temperature=0)
50
+
51
+ self.PLAN_PROMPT = ("You are an expert writer tasked with writing a high level outline of a short 3 paragraph essay. "
52
+ "Write such an outline for the user provided topic. Give the three main headers of an outline of "
53
+ "the essay along with any relevant notes or instructions for the sections. ")
54
+ self.WRITER_PROMPT = ("You are an essay assistant tasked with writing excellent 3 paragraph essays. "
55
+ "Generate the best essay possible for the user's request and the initial outline. "
56
+ "If the user provides critique, respond with a revised version of your previous attempts. "
57
+ "Utilize all the information below as needed: \n"
58
+ "------\n"
59
+ "{content}")
60
+ self.RESEARCH_PLAN_PROMPT = ("You are a researcher charged with providing information that can "
61
+ "be used when writing the following essay. Generate a list of search "
62
+ "queries that will gather "
63
+ "any relevant information. Only generate 3 queries max.")
64
+ self.REFLECTION_PROMPT = ("You are a teacher grading an 3 paragraph essay submission. "
65
+ "Generate critique and recommendations for the user's submission. "
66
+ "Provide detailed recommendations, including requests for length, depth, style, etc.")
67
+ self.RESEARCH_CRITIQUE_PROMPT = ("You are a researcher charged with providing information that can "
68
+ "be used when making any requested revisions (as outlined below). "
69
+ "Generate a list of search queries that will gather any relevant information. "
70
+ "Only generate 2 queries max.")
71
+
72
+ self.tavily = TavilyClient(api_key=TAVILY_API_KEY)
73
+
74
+ builder = StateGraph(AgentState)
75
+ builder.add_node("planner", self.plan_node)
76
+ builder.add_node("research_plan", self.research_plan_node)
77
+ builder.add_node("generate", self.generation_node)
78
+ builder.add_node("reflect", self.reflection_node)
79
+ builder.add_node("research_critique", self.research_critique_node)
80
+ builder.set_entry_point("planner")
81
+ builder.add_conditional_edges(
82
+ "generate",
83
+ self.should_continue,
84
+ {END: END, "reflect": "reflect"}
85
+ )
86
+ builder.add_edge("planner", "research_plan")
87
+ builder.add_edge("research_plan", "generate")
88
+ builder.add_edge("reflect", "research_critique")
89
+ builder.add_edge("research_critique", "generate")
90
+ memory = SqliteSaver(conn=sqlite3.connect(":memory:", check_same_thread=False))
91
+ self.graph = builder.compile(
92
+ checkpointer=memory,
93
+ interrupt_after=['planner', 'generate', 'reflect', 'research_plan', 'research_critique']
94
+ )
95
+
96
+ def plan_node(self, state: AgentState):
97
+ messages = [
98
+ SystemMessage(content=self.PLAN_PROMPT),
99
+ HumanMessage(content=state['task'])
100
+ ]
101
+ response = self.model.invoke(messages)
102
+ return {"plan": response.content,
103
+ "lnode": "planner",
104
+ "count": 1,
105
+ }
106
+
107
+ def research_plan_node(self, state: AgentState):
108
+ queries = self.model.with_structured_output(Queries).invoke([
109
+ SystemMessage(content=self.RESEARCH_PLAN_PROMPT),
110
+ HumanMessage(content=state['task'])
111
+ ])
112
+ content = state['content'] or [] # add to content
113
+ for q in queries.queries:
114
+ response = self.tavily.search(query=q, max_results=2)
115
+ for r in response['results']:
116
+ content.append(r['content'])
117
+ return {"content": content,
118
+ "queries": queries.queries,
119
+ "lnode": "research_plan",
120
+ "count": 1,
121
+ }
122
+
123
+ def generation_node(self, state: AgentState):
124
+ content = "\n\n".join(state['content'] or [])
125
+ user_message = HumanMessage(
126
+ content=f"{state['task']}\n\nHere is my plan:\n\n{state['plan']}")
127
+ messages = [
128
+ SystemMessage(
129
+ content=self.WRITER_PROMPT.format(content=content)
130
+ ),
131
+ user_message
132
+ ]
133
+ response = self.model.invoke(messages)
134
+ return {
135
+ "draft": response.content,
136
+ "revision_number": state.get("revision_number", 1) + 1,
137
+ "lnode": "generate",
138
+ "count": 1,
139
+ }
140
+
141
+ def reflection_node(self, state: AgentState):
142
+ messages = [
143
+ SystemMessage(content=self.REFLECTION_PROMPT),
144
+ HumanMessage(content=state['draft'])
145
+ ]
146
+ response = self.model.invoke(messages)
147
+ return {"critique": response.content,
148
+ "lnode": "reflect",
149
+ "count": 1,
150
+ }
151
+
152
+ def research_critique_node(self, state: AgentState):
153
+ queries = self.model.with_structured_output(Queries).invoke([
154
+ SystemMessage(content=self.RESEARCH_CRITIQUE_PROMPT),
155
+ HumanMessage(content=state['critique'])
156
+ ])
157
+ content = state['content'] or []
158
+ for q in queries.queries:
159
+ response = self.tavily.search(query=q, max_results=2)
160
+ for r in response['results']:
161
+ content.append(r['content'])
162
+ return {"content": content,
163
+ "lnode": "research_critique",
164
+ "count": 1,
165
+ }
166
+
167
+ def should_continue(self, state):
168
+ if state["revision_number"] > state["max_revisions"]:
169
+ return END
170
+ return "reflect"
171
+
172
+
173
+
174
+ class writer_gui( ):
175
+ def __init__(self, graph, share=False):
176
+ self.graph = graph
177
+ self.share = share
178
+ self.partial_message = ""
179
+ self.response = {}
180
+ self.max_iterations = 10
181
+ self.iterations = []
182
+ self.threads = []
183
+ self.thread_id = -1
184
+ self.thread = {"configurable": {"thread_id": str(self.thread_id)}}
185
+ #self.sdisps = {} #global
186
+ self.demo = self.create_interface()
187
+
188
+ def run_agent(self, start,topic,stop_after):
189
+ #global partial_message, thread_id,thread
190
+ #global response, max_iterations, iterations, threads
191
+ if start:
192
+ self.iterations.append(0)
193
+ config = {'task': topic,"max_revisions": 2,"revision_number": 0,
194
+ 'lnode': "", 'planner': "no plan", 'draft': "no draft", 'critique': "no critique",
195
+ 'content': ["no content",], 'queries': "no queries", 'count':0}
196
+ self.thread_id += 1 # new agent, new thread
197
+ self.threads.append(self.thread_id)
198
+ else:
199
+ config = None
200
+
201
+ self.thread = {"configurable": {"thread_id": str(self.thread_id)}}
202
+
203
+ while self.iterations[self.thread_id] < self.max_iterations:
204
+ self.response = self.graph.invoke(config, self.thread)
205
+ self.iterations[self.thread_id] += 1
206
+ self.partial_message += str(self.response)
207
+ self.partial_message += f"\n------------------\n\n"
208
+ ## fix
209
+ lnode, nnode, _, rev, acount = self.get_disp_state()
210
+ yield self.partial_message, lnode, nnode, self.thread_id, rev, acount
211
+ config = None #need
212
+ #print(f"run_agent:{lnode}")
213
+ if not nnode:
214
+ #print("Hit the end")
215
+ return
216
+ if lnode in stop_after:
217
+ #print(f"stopping due to stop_after {lnode}")
218
+ return
219
+ else:
220
+ #print(f"Not stopping on lnode {lnode}")
221
+ pass
222
+ return
223
+
224
+ def get_disp_state(self,):
225
+ current_state = self.graph.get_state(self.thread)
226
+ lnode = current_state.values["lnode"]
227
+ acount = current_state.values["count"]
228
+ rev = current_state.values["revision_number"]
229
+ nnode = current_state.next
230
+ #print (lnode,nnode,self.thread_id,rev,acount)
231
+ return lnode,nnode,self.thread_id,rev,acount
232
+
233
+ def get_state(self,key):
234
+ current_values = self.graph.get_state(self.thread)
235
+ if key in current_values.values:
236
+ lnode,nnode,self.thread_id,rev,astep = self.get_disp_state()
237
+ new_label = f"last_node: {lnode}, thread_id: {self.thread_id}, rev: {rev}, step: {astep}"
238
+ return gr.update(label=new_label, value=current_values.values[key])
239
+ else:
240
+ return ""
241
+
242
+ def get_content(self,):
243
+ current_values = self.graph.get_state(self.thread)
244
+ if "content" in current_values.values:
245
+ content = current_values.values["content"]
246
+ lnode,nnode,thread_id,rev,astep = self.get_disp_state()
247
+ new_label = f"last_node: {lnode}, thread_id: {self.thread_id}, rev: {rev}, step: {astep}"
248
+ return gr.update(label=new_label, value="\n\n".join(item for item in content) + "\n\n")
249
+ else:
250
+ return ""
251
+
252
+ def update_hist_pd(self,):
253
+ #print("update_hist_pd")
254
+ hist = []
255
+ # curiously, this generator returns the latest first
256
+ for state in self.graph.get_state_history(self.thread):
257
+ if state.metadata['step'] < 1:
258
+ continue
259
+ thread_ts = state.config['configurable']['thread_ts']
260
+ tid = state.config['configurable']['thread_id']
261
+ count = state.values['count']
262
+ lnode = state.values['lnode']
263
+ rev = state.values['revision_number']
264
+ nnode = state.next
265
+ st = f"{tid}:{count}:{lnode}:{nnode}:{rev}:{thread_ts}"
266
+ hist.append(st)
267
+ return gr.Dropdown(label="update_state from: thread:count:last_node:next_node:rev:thread_ts",
268
+ choices=hist, value=hist[0],interactive=True)
269
+
270
+ def find_config(self,thread_ts):
271
+ for state in self.graph.get_state_history(self.thread):
272
+ config = state.config
273
+ if config['configurable']['thread_ts'] == thread_ts:
274
+ return config
275
+ return(None)
276
+
277
+ def copy_state(self, hist_str):
278
+ ''' result of selecting an old state from the step pulldown. Note does not change thread.
279
+ This copies an old state to a new current state.
280
+ '''
281
+ thread_ts = hist_str.split(":")[-1]
282
+ #print(f"copy_state from {thread_ts}")
283
+ config = self.find_config(thread_ts)
284
+ #print(config)
285
+ state = self.graph.get_state(config)
286
+ self.graph.update_state(self.thread, state.values, as_node=state.values['lnode'])
287
+ new_state = self.graph.get_state(self.thread) #should now match
288
+ new_thread_ts = new_state.config['configurable']['thread_ts']
289
+ tid = new_state.config['configurable']['thread_id']
290
+ count = new_state.values['count']
291
+ lnode = new_state.values['lnode']
292
+ rev = new_state.values['revision_number']
293
+ nnode = new_state.next
294
+ return lnode,nnode,new_thread_ts,rev,count
295
+
296
+ def update_thread_pd(self,):
297
+ #print("update_thread_pd")
298
+ return gr.Dropdown(label="choose thread", choices=self.threads, value=str(self.thread_id), interactive=True)
299
+
300
+ def switch_thread(self,new_thread_id):
301
+ #print(f"switch_thread{new_thread_id}")
302
+ self.thread = {"configurable": {"thread_id": str(new_thread_id)}}
303
+ self.thread_id = new_thread_id
304
+ return
305
+
306
+ def modify_state(self,key,asnode,new_state):
307
+ ''' gets the current state, modifes a single value in the state identified by key, and updates state with it.
308
+ note that this will create a new 'current state' node. If you do this multiple times with different keys, it will create
309
+ one for each update. Note also that it doesn't resume after the update
310
+ '''
311
+ current_values = self.graph.get_state(self.thread)
312
+ current_values.values[key] = new_state
313
+ self.graph.update_state(self.thread, current_values.values,as_node=asnode)
314
+ return
315
+
316
+ def create_interface(self):
317
+ with gr.Blocks(theme=gr.themes.Default(spacing_size='sm',text_size="sm")) as demo:
318
+
319
+ def updt_disp():
320
+ ''' general update display on state change '''
321
+ current_state = self.graph.get_state(self.thread)
322
+ hist = []
323
+ # curiously, this generator returns the latest first
324
+ for state in self.graph.get_state_history(self.thread):
325
+ if state.metadata['step'] < 1: #ignore early states
326
+ continue
327
+ s_thread_ts = state.config['configurable']['thread_ts']
328
+ s_tid = state.config['configurable']['thread_id']
329
+ s_count = state.values['count']
330
+ s_lnode = state.values['lnode']
331
+ s_rev = state.values['revision_number']
332
+ s_nnode = state.next
333
+ st = f"{s_tid}:{s_count}:{s_lnode}:{s_nnode}:{s_rev}:{s_thread_ts}"
334
+ hist.append(st)
335
+ if not current_state.metadata: #handle init call
336
+ return{}
337
+ else:
338
+ return {
339
+ topic_bx : current_state.values["task"],
340
+ lnode_bx : current_state.values["lnode"],
341
+ count_bx : current_state.values["count"],
342
+ revision_bx : current_state.values["revision_number"],
343
+ nnode_bx : current_state.next,
344
+ threadid_bx : self.thread_id,
345
+ thread_pd : gr.Dropdown(label="choose thread", choices=self.threads,
346
+ value=str(self.thread_id), interactive=True),
347
+ step_pd : gr.Dropdown(label="update_state from: thread:count:last_node:next_node:rev:thread_ts",
348
+ choices=hist, value=hist[0], interactive=True),
349
+ }
350
+
351
+ def get_snapshots():
352
+ new_label = f"thread_id: {self.thread_id}, Summary of snapshots"
353
+ sstate = ""
354
+ for state in self.graph.get_state_history(self.thread):
355
+ for key in ['plan', 'draft', 'critique']:
356
+ if key in state.values:
357
+ state.values[key] = state.values[key][:80] + "..."
358
+ if 'content' in state.values:
359
+ for i in range(len(state.values['content'])):
360
+ state.values['content'][i] = state.values['content'][i][:20] + '...'
361
+ if 'writes' in state.metadata:
362
+ state.metadata['writes'] = "not shown"
363
+ sstate += str(state) + "\n\n"
364
+ return gr.update(label=new_label, value=sstate)
365
+
366
+ def vary_btn(stat):
367
+ #print(f"vary_btn{stat}")
368
+ return(gr.update(variant=stat))
369
+
370
+ with gr.Tab("Agent"):
371
+ with gr.Row():
372
+ topic_bx = gr.Textbox(label="Essay Topic", value="Pizza Shop")
373
+ gen_btn = gr.Button("Generate Essay", scale=0,min_width=80, variant='primary')
374
+ cont_btn = gr.Button("Continue Essay", scale=0,min_width=80)
375
+ with gr.Row():
376
+ lnode_bx = gr.Textbox(label="last node", min_width=100)
377
+ nnode_bx = gr.Textbox(label="next node", min_width=100)
378
+ threadid_bx = gr.Textbox(label="Thread", scale=0, min_width=80)
379
+ revision_bx = gr.Textbox(label="Draft Rev", scale=0, min_width=80)
380
+ count_bx = gr.Textbox(label="count", scale=0, min_width=80)
381
+ with gr.Accordion("Manage Agent", open=False):
382
+ checks = list(self.graph.nodes.keys())
383
+ checks.remove('__start__')
384
+ stop_after = gr.CheckboxGroup(checks,label="Interrupt After State", value=checks, scale=0, min_width=400)
385
+ with gr.Row():
386
+ thread_pd = gr.Dropdown(choices=self.threads,interactive=True, label="select thread", min_width=120, scale=0)
387
+ step_pd = gr.Dropdown(choices=['N/A'],interactive=True, label="select step", min_width=160, scale=1)
388
+ live = gr.Textbox(label="Live Agent Output", lines=5, max_lines=5)
389
+
390
+ # actions
391
+ sdisps =[topic_bx,lnode_bx,nnode_bx,threadid_bx,revision_bx,count_bx,step_pd,thread_pd]
392
+ thread_pd.input(self.switch_thread, [thread_pd], None).then(
393
+ fn=updt_disp, inputs=None, outputs=sdisps)
394
+
395
+ step_pd.input(self.copy_state,[step_pd],None).then(
396
+ fn=updt_disp, inputs=None, outputs=sdisps)
397
+
398
+ gen_btn.click(vary_btn, gr.Number(label="secondary", visible=False), gen_btn).then(
399
+ fn=self.run_agent, inputs=[gr.Number(True, visible=False),topic_bx,stop_after], outputs=[live],show_progress=True).then(
400
+ fn=updt_disp, inputs=None, outputs=sdisps).then(
401
+ vary_btn,gr.Number(label="primary", visible=False), gen_btn).then(
402
+ vary_btn,gr.Number(label="primary", visible=False), cont_btn)
403
+
404
+ cont_btn.click(vary_btn,gr.Number(label="secondary", visible=False), cont_btn).then(
405
+ fn=self.run_agent, inputs=[gr.Number(False, visible=False),topic_bx,stop_after],
406
+ outputs=[live]).then(
407
+ fn=updt_disp, inputs=None, outputs=sdisps).then(
408
+ vary_btn,gr.Number(label="primary", visible=False), cont_btn)
409
+
410
+ with gr.Tab("Plan"):
411
+ with gr.Row():
412
+ refresh_btn = gr.Button("Refresh")
413
+ modify_btn = gr.Button("Modify")
414
+ plan = gr.Textbox(label="Plan", lines=10, interactive=True)
415
+ refresh_btn.click(fn=self.get_state, inputs=gr.Number(label="plan", visible=False), outputs=plan)
416
+ modify_btn.click(fn=self.modify_state, inputs=[gr.Number(label="plan", visible=False),
417
+ gr.Number(label="planner", visible=False), plan],outputs=None).then(
418
+ fn=updt_disp, inputs=None, outputs=sdisps)
419
+ with gr.Tab("Research Content"):
420
+ refresh_btn = gr.Button("Refresh")
421
+ content_bx = gr.Textbox(label="content", lines=10)
422
+ refresh_btn.click(fn=self.get_content, inputs=None, outputs=content_bx)
423
+ with gr.Tab("Draft"):
424
+ with gr.Row():
425
+ refresh_btn = gr.Button("Refresh")
426
+ modify_btn = gr.Button("Modify")
427
+ draft_bx = gr.Textbox(label="draft", lines=10, interactive=True)
428
+ refresh_btn.click(fn=self.get_state, inputs=gr.Number(label="draft", visible=False), outputs=draft_bx)
429
+ modify_btn.click(fn=self.modify_state, inputs=[gr.Number(label="draft", visible=False),
430
+ gr.Number(label="generate", visible=False),
431
+ draft_bx], outputs=None).then(
432
+ fn=updt_disp, inputs=None, outputs=sdisps)
433
+
434
+ with gr.Tab("Critique"):
435
+ with gr.Row():
436
+ refresh_btn = gr.Button("Refresh")
437
+ modify_btn = gr.Button("Modify")
438
+ critique_bx = gr.Textbox(label="Critique", lines=10, interactive=True)
439
+ refresh_btn.click(fn=self.get_state, inputs=gr.Number(label="critique", visible=False), outputs=critique_bx)
440
+ modify_btn.click(fn=self.modify_state, inputs=[gr.Number(label="critique", visible=False),
441
+ gr.Number(label="reflect", visible=False),
442
+ critique_bx], outputs=None).then(
443
+ fn=updt_disp, inputs=None, outputs=sdisps)
444
+
445
+ with gr.Tab("StateSnapShots"):
446
+ with gr.Row():
447
+ refresh_btn = gr.Button("Refresh")
448
+ snapshots = gr.Textbox(label="State Snapshots Summaries")
449
+ refresh_btn.click(fn=get_snapshots, inputs=None, outputs=snapshots)
450
+ return demo
451
+
452
+ def launch(self):
453
+ if port := os.getenv("PORT1"):
454
+ self.demo.launch(share=True, server_port=int(port), server_name="0.0.0.0")
455
+ else:
456
+ self.demo.launch(share=self.share)