NaderAfshar committed on
Commit
664e9ba
·
1 Parent(s): 73c1ece

Remove helper.py from repository but keep locally

Browse files
Files changed (1) hide show
  1. helper.py +0 -475
helper.py DELETED
@@ -1,475 +0,0 @@
1
- import warnings
2
- from dotenv import load_dotenv
3
-
4
- import os
5
- import time
6
- import gradio as gr
7
-
8
- from langgraph.graph import StateGraph, END
9
- from langgraph.checkpoint.sqlite import SqliteSaver
10
- from typing import List, TypedDict, Annotated
11
- import operator
12
- from langchain_core.messages import (AnyMessage,
13
- SystemMessage,
14
- HumanMessage,
15
- ToolMessage,
16
- AIMessage )
17
- from langchain_cohere import ChatCohere
18
- from tavily import TavilyClient
19
- from pydantic import BaseModel
20
- import sqlite3
21
-
22
# Suppress the tqdm deprecation warning emitted by downstream libraries.
warnings.filterwarnings("ignore", message=".*TqdmWarning.*")
# Load variables from a local .env file into the process environment
# (return value intentionally discarded).
_ = load_dotenv()

# API credentials read from the environment; None when unset.
CO_API_KEY = os.getenv("COHERE_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

# Cohere chat model identifier used for every LLM call in this module.
cohere_model = "command-a-03-2025"
29
-
30
-
31
class AgentState(TypedDict):
    """Shared state carried between the LangGraph nodes of the essay agent."""
    task: str              # user-provided essay topic
    lnode: str             # name of the last node that executed
    plan: str              # outline produced by the planner node
    draft: str             # current essay draft
    critique: str          # reviewer feedback on the latest draft
    content: List[str]     # research snippets gathered via Tavily searches
    queries: List[str]     # search queries most recently generated
    revision_number: int   # incremented each time a draft is generated
    max_revisions: int     # stop reflecting once revision_number exceeds this
    count: Annotated[int, operator.add]  # node-execution counter (reducer: sum)
42
-
43
-
44
class Queries(BaseModel):
    """Structured-output schema the LLM fills when asked for search queries."""
    queries: List[str]  # query strings; the prompts cap these at 2-3 entries
46
-
47
-
48
class ewriter():
    """Essay-writing agent built on LangGraph.

    Wires five nodes -- plan, research, generate, reflect, critique-research --
    into a state machine that loops draft/critique cycles until the revision
    budget is spent.  The compiled graph checkpoints into an in-memory SQLite
    store and interrupts after every node so a UI can step it interactively.
    """

    def __init__(self):
        # One Cohere chat model shared by every node; temperature 0 keeps
        # repeated runs as deterministic as the API allows.
        self.model = ChatCohere(api_key=CO_API_KEY, model=cohere_model, temperature=0)

        # --- system prompts (verbatim contracts sent to the model) ----------
        self.PLAN_PROMPT = ("You are an expert writer tasked with writing a high level outline of a short 3 paragraph essay. "
                            "Write such an outline for the user provided topic. Give the three main headers of an outline of "
                            "the essay along with any relevant notes or instructions for the sections. ")
        self.WRITER_PROMPT = ("You are an essay assistant tasked with writing excellent 3 paragraph essays. "
                              "Generate the best essay possible for the user's request and the initial outline. "
                              "If the user provides critique, respond with a revised version of your previous attempts. "
                              "Utilize all the information below as needed: \n"
                              "------\n"
                              "{content}")
        self.RESEARCH_PLAN_PROMPT = ("You are a researcher charged with providing information that can "
                                     "be used when writing the following essay. Generate a list of search "
                                     "queries that will gather "
                                     "any relevant information. Only generate 3 queries max.")
        self.REFLECTION_PROMPT = ("You are a teacher grading an 3 paragraph essay submission. "
                                  "Generate critique and recommendations for the user's submission. "
                                  "Provide detailed recommendations, including requests for length, depth, style, etc.")
        self.RESEARCH_CRITIQUE_PROMPT = ("You are a researcher charged with providing information that can "
                                         "be used when making any requested revisions (as outlined below). "
                                         "Generate a list of search queries that will gather any relevant information. "
                                         "Only generate 2 queries max.")

        # Web-search client used by both research nodes.
        self.tavily = TavilyClient(api_key=TAVILY_API_KEY)

        # --- graph wiring ---------------------------------------------------
        wiring = StateGraph(AgentState)
        for node_name, node_fn in [("planner", self.plan_node),
                                   ("research_plan", self.research_plan_node),
                                   ("generate", self.generation_node),
                                   ("reflect", self.reflection_node),
                                   ("research_critique", self.research_critique_node)]:
            wiring.add_node(node_name, node_fn)
        wiring.set_entry_point("planner")
        wiring.add_edge("planner", "research_plan")
        wiring.add_edge("research_plan", "generate")
        wiring.add_edge("reflect", "research_critique")
        wiring.add_edge("research_critique", "generate")
        # After each draft, either stop or loop back through reflection.
        wiring.add_conditional_edges(
            "generate",
            self.should_continue,
            {END: END, "reflect": "reflect"}
        )
        # In-memory checkpointing; check_same_thread=False because the UI may
        # drive the graph from worker threads.
        checkpoints = SqliteSaver(conn=sqlite3.connect(":memory:", check_same_thread=False))
        self.graph = wiring.compile(
            checkpointer=checkpoints,
            interrupt_after=['planner', 'generate', 'reflect', 'research_plan', 'research_critique']
        )

    def plan_node(self, state: AgentState):
        """Ask the model for a three-header outline of the requested essay."""
        reply = self.model.invoke([
            SystemMessage(content=self.PLAN_PROMPT),
            HumanMessage(content=state['task']),
        ])
        return {"plan": reply.content,
                "lnode": "planner",
                "count": 1}

    def research_plan_node(self, state: AgentState):
        """Generate search queries for the topic and gather Tavily results."""
        generated = self.model.with_structured_output(Queries).invoke([
            SystemMessage(content=self.RESEARCH_PLAN_PROMPT),
            HumanMessage(content=state['task']),
        ])
        # Accumulate onto the existing research list (created if absent).
        gathered = state['content'] or []
        for query in generated.queries:
            hits = self.tavily.search(query=query, max_results=2)
            gathered.extend(hit['content'] for hit in hits['results'])
        return {"content": gathered,
                "queries": generated.queries,
                "lnode": "research_plan",
                "count": 1}

    def generation_node(self, state: AgentState):
        """Draft (or redraft) the essay from the plan plus gathered research."""
        research = "\n\n".join(state['content'] or [])
        request = HumanMessage(
            content=f"{state['task']}\n\nHere is my plan:\n\n{state['plan']}")
        reply = self.model.invoke([
            SystemMessage(content=self.WRITER_PROMPT.format(content=research)),
            request,
        ])
        return {"draft": reply.content,
                "revision_number": state.get("revision_number", 1) + 1,
                "lnode": "generate",
                "count": 1}

    def reflection_node(self, state: AgentState):
        """Grade the current draft and produce critique/recommendations."""
        reply = self.model.invoke([
            SystemMessage(content=self.REFLECTION_PROMPT),
            HumanMessage(content=state['draft']),
        ])
        return {"critique": reply.content,
                "lnode": "reflect",
                "count": 1}

    def research_critique_node(self, state: AgentState):
        """Generate follow-up queries from the critique and gather results."""
        generated = self.model.with_structured_output(Queries).invoke([
            SystemMessage(content=self.RESEARCH_CRITIQUE_PROMPT),
            HumanMessage(content=state['critique']),
        ])
        gathered = state['content'] or []
        for query in generated.queries:
            hits = self.tavily.search(query=query, max_results=2)
            gathered.extend(hit['content'] for hit in hits['results'])
        return {"content": gathered,
                "lnode": "research_critique",
                "count": 1}

    def should_continue(self, state):
        """Route to END once the revision budget is exhausted, else reflect."""
        return END if state["revision_number"] > state["max_revisions"] else "reflect"
172
-
173
-
174
-
175
class writer_gui( ):
    """Gradio management console for an ewriter-style LangGraph agent.

    Lets the user start/continue essay runs, inspect and edit intermediate
    state (plan, draft, critique, research content), browse checkpoint
    history, and time-travel by copying an old snapshot forward.
    """

    def __init__(self, graph, share=False):
        # graph: a compiled LangGraph (e.g. ewriter().graph) with a checkpointer.
        # share: forwarded to Gradio's launch() when PORT1 is not set.
        self.graph = graph
        self.share = share
        self.partial_message = ""      # accumulated text shown in the live output box
        self.response = {}             # last graph.invoke() result
        self.max_iterations = 10       # hard cap on node steps per thread
        self.iterations = []           # per-thread step counters, indexed by thread_id
        self.threads = []              # thread ids created so far
        self.thread_id = -1            # current thread; -1 until the first run
        self.thread_ts_map = {}        # thread_id -> stable thread_ts used in configs
        self.thread = {"configurable": {"thread_id": str(self.thread_id)}}
        #self.sdisps = {} #global
        self.demo = self.create_interface()

    def run_agent(self, start,topic,stop_after):
        """Generator driving the graph; yields display tuples after each step.

        start: True begins a fresh thread with an initial config; False
               resumes the current thread from its checkpoint.
        topic: essay topic passed as the agent's task.
        stop_after: node names after which streaming should stop.
        """
        #global partial_message, thread_id,thread
        #global response, max_iterations, iterations, threads
        if start:
            self.iterations.append(0)
            # NOTE(review): 'planner' here does not match AgentState's 'plan'
            # key, and 'queries' is seeded with a str where AgentState declares
            # List[str] -- both appear to work only because downstream nodes
            # overwrite them; confirm before relying on these seeds.
            config = {'task': topic,"max_revisions": 2,"revision_number": 0,
                      'lnode': "", 'planner': "no plan", 'draft': "no draft", 'critique': "no critique",
                      'content': ["no content",], 'queries': "no queries", 'count':0}
            self.thread_id += 1  # new agent, new thread
            self.threads.append(self.thread_id)

            # Reuse a stable thread_ts per thread so history lookups line up.
            if self.thread_id not in self.thread_ts_map:
                self.thread_ts_map[self.thread_id] = str(self.thread_id)  # or use a stable UUID

            self.thread = {"configurable": {
                "thread_id": str(self.thread_id),
                "thread_ts": self.thread_ts_map[self.thread_id],
            }}

        else:
            # Resuming: invoke with config=None so the checkpointer supplies state.
            config = None

            self.thread = {"configurable": {"thread_id": str(self.thread_id)}}

        while self.iterations[self.thread_id] < self.max_iterations:
            self.response = self.graph.invoke(config, self.thread)
            self.iterations[self.thread_id] += 1
            self.partial_message += str(self.response)
            self.partial_message += f"\n------------------\n\n"
            # Pull display fields from the checkpointed state, then stream them.
            lnode, nnode, _, rev, acount = self.get_disp_state()
            yield self.partial_message, lnode, nnode, self.thread_id, rev, acount
            config = None  # only the first invoke of a new thread carries config
            #print(f"run_agent:{lnode}")
            if not nnode:
                # Graph reached END -- nothing left to run.
                return
            if lnode in stop_after:
                # User asked to interrupt after this node.
                return
            else:
                #print(f"Not stopping on lnode {lnode}")
                pass
        return

    def get_disp_state(self,):
        """Return (last_node, next_node, thread_id, revision, count) for display."""
        current_state = self.graph.get_state(self.thread)
        lnode = current_state.values["lnode"]
        acount = current_state.values["count"]
        rev = current_state.values["revision_number"]
        nnode = current_state.next
        #print (lnode,nnode,self.thread_id,rev,acount)
        return lnode,nnode,self.thread_id,rev,acount

    def get_state(self,key):
        """Return a gr.update for one state value (with a summary label), or ""."""
        current_values = self.graph.get_state(self.thread)
        if key in current_values.values:
            lnode,nnode,self.thread_id,rev,astep = self.get_disp_state()
            new_label = f"last_node: {lnode}, thread_id: {self.thread_id}, rev: {rev}, step: {astep}"
            return gr.update(label=new_label, value=current_values.values[key])
        else:
            return ""

    def get_content(self,):
        """Return a gr.update showing all research snippets joined together."""
        current_values = self.graph.get_state(self.thread)
        if "content" in current_values.values:
            content = current_values.values["content"]
            lnode,nnode,thread_id,rev,astep = self.get_disp_state()
            new_label = f"last_node: {lnode}, thread_id: {self.thread_id}, rev: {rev}, step: {astep}"
            return gr.update(label=new_label, value="\n\n".join(item for item in content) + "\n\n")
        else:
            return ""

    def update_hist_pd(self,):
        """Rebuild the step dropdown from the thread's checkpoint history."""
        hist = []
        # curiously, this generator returns the latest first
        for state in self.graph.get_state_history(self.thread):
            if state.metadata['step'] < 1:
                continue
            thread_ts = state.config['configurable']['thread_ts']
            tid = state.config['configurable']['thread_id']
            count = state.values['count']
            lnode = state.values['lnode']
            rev = state.values['revision_number']
            nnode = state.next
            st = f"{tid}:{count}:{lnode}:{nnode}:{rev}:{thread_ts}"
            hist.append(st)
        return gr.Dropdown(label="update_state from: thread:count:last_node:next_node:rev:thread_ts",
                           choices=hist, value=hist[0],interactive=True)

    def find_config(self,thread_ts):
        """Scan checkpoint history for the config matching thread_ts; None if absent."""
        for state in self.graph.get_state_history(self.thread):
            config = state.config
            if config['configurable']['thread_ts'] == thread_ts:
                return config
        return(None)

    def copy_state(self, hist_str):
        ''' result of selecting an old state from the step pulldown. Note does not change thread.
        This copies an old state to a new current state.
        '''
        # hist_str format: "tid:count:lnode:nnode:rev:thread_ts" -- take the ts.
        thread_ts = hist_str.split(":")[-1]
        #print(f"copy_state from {thread_ts}")
        config = self.find_config(thread_ts)
        #print(config)
        state = self.graph.get_state(config)
        # Re-apply the old values as if the old node just ran, creating a new head.
        self.graph.update_state(self.thread, state.values, as_node=state.values['lnode'])
        new_state = self.graph.get_state(self.thread) #should now match
        new_thread_ts = new_state.config['configurable']['thread_ts']
        tid = new_state.config['configurable']['thread_id']
        count = new_state.values['count']
        lnode = new_state.values['lnode']
        rev = new_state.values['revision_number']
        nnode = new_state.next
        return lnode,nnode,new_thread_ts,rev,count

    def update_thread_pd(self,):
        """Rebuild the thread-selection dropdown from known thread ids."""
        return gr.Dropdown(label="choose thread", choices=self.threads, value=str(self.thread_id), interactive=True)

    def switch_thread(self,new_thread_id):
        """Point subsequent graph operations at a different existing thread."""
        self.thread = {"configurable": {"thread_id": str(new_thread_id)}}
        self.thread_id = new_thread_id
        return

    def modify_state(self,key,asnode,new_state):
        ''' gets the current state, modifes a single value in the state identified by key, and updates state with it.
        note that this will create a new 'current state' node. If you do this multiple times with different keys, it will create
        one for each update. Note also that it doesn't resume after the update
        '''
        current_values = self.graph.get_state(self.thread)
        current_values.values[key] = new_state
        self.graph.update_state(self.thread, current_values.values,as_node=asnode)
        return

    def create_interface(self):
        """Build and return the full Gradio Blocks app (not launched here)."""
        with gr.Blocks(theme=gr.themes.Default(spacing_size='sm',text_size="sm")) as demo:

            def updt_disp():
                ''' general update display on state change '''
                current_state = self.graph.get_state(self.thread)
                hist = []
                # curiously, this generator returns the latest first
                for state in self.graph.get_state_history(self.thread):
                    if state.metadata['step'] < 1: #ignore early states
                        continue
                    s_thread_ts = state.config['configurable']['thread_ts']
                    s_tid = state.config['configurable']['thread_id']
                    s_count = state.values['count']
                    s_lnode = state.values['lnode']
                    s_rev = state.values['revision_number']
                    s_nnode = state.next
                    st = f"{s_tid}:{s_count}:{s_lnode}:{s_nnode}:{s_rev}:{s_thread_ts}"
                    hist.append(st)
                if not current_state.metadata: #handle init call
                    return{}
                else:
                    return {
                        topic_bx : current_state.values["task"],
                        lnode_bx : current_state.values["lnode"],
                        count_bx : current_state.values["count"],
                        revision_bx : current_state.values["revision_number"],
                        nnode_bx : current_state.next,
                        threadid_bx : self.thread_id,
                        thread_pd : gr.Dropdown(label="choose thread", choices=self.threads,
                                                value=str(self.thread_id), interactive=True),
                        step_pd : gr.Dropdown(label="update_state from: thread:count:last_node:next_node:rev:thread_ts",
                                              choices=hist, value=hist[0], interactive=True),
                    }

            def get_snapshots():
                # Summarize every checkpoint, truncating long fields for display.
                new_label = f"thread_id: {self.thread_id}, Summary of snapshots"
                sstate = ""
                for state in self.graph.get_state_history(self.thread):
                    for key in ['plan', 'draft', 'critique']:
                        if key in state.values:
                            state.values[key] = state.values[key][:80] + "..."
                    if 'content' in state.values:
                        for i in range(len(state.values['content'])):
                            state.values['content'][i] = state.values['content'][i][:20] + '...'
                    if 'writes' in state.metadata:
                        state.metadata['writes'] = "not shown"
                    sstate += str(state) + "\n\n"
                return gr.update(label=new_label, value=sstate)

            def vary_btn(stat):
                # Toggle a button's visual variant ('primary'/'secondary').
                return(gr.update(variant=stat))

            with gr.Tab("Agent"):
                with gr.Row():
                    topic_bx = gr.Textbox(label="Essay Topic", value="Pizza Shop")
                    gen_btn = gr.Button("Generate Essay", scale=0,min_width=80, variant='primary')
                    cont_btn = gr.Button("Continue Essay", scale=0,min_width=80)
                with gr.Row():
                    lnode_bx = gr.Textbox(label="last node", min_width=100)
                    nnode_bx = gr.Textbox(label="next node", min_width=100)
                    threadid_bx = gr.Textbox(label="Thread", scale=0, min_width=80)
                    revision_bx = gr.Textbox(label="Draft Rev", scale=0, min_width=80)
                    count_bx = gr.Textbox(label="count", scale=0, min_width=80)
                with gr.Accordion("Manage Agent", open=False):
                    # Offer every graph node (except the synthetic start) as an
                    # interrupt point.
                    checks = list(self.graph.nodes.keys())
                    checks.remove('__start__')
                    stop_after = gr.CheckboxGroup(checks,label="Interrupt After State", value=checks, scale=0, min_width=400)
                    with gr.Row():
                        thread_pd = gr.Dropdown(choices=self.threads,interactive=True, label="select thread", min_width=120, scale=0)
                        step_pd = gr.Dropdown(choices=['N/A'],interactive=True, label="select step", min_width=160, scale=1)
                live = gr.Textbox(label="Live Agent Output", lines=5, max_lines=5)

                # actions
                # NOTE(review): gr.Number is used below to smuggle constant
                # (often string) payloads into callbacks -- confirm this works
                # on the pinned Gradio version.
                sdisps =[topic_bx,lnode_bx,nnode_bx,threadid_bx,revision_bx,count_bx,step_pd,thread_pd]
                thread_pd.input(self.switch_thread, [thread_pd], None).then(
                    fn=updt_disp, inputs=None, outputs=sdisps)

                step_pd.input(self.copy_state,[step_pd],None).then(
                    fn=updt_disp, inputs=None, outputs=sdisps)

                gen_btn.click(vary_btn, gr.Number(label="secondary", visible=False), gen_btn).then(
                    fn=self.run_agent, inputs=[gr.Number(True, visible=False),topic_bx,stop_after], outputs=[live],show_progress=True).then(
                    fn=updt_disp, inputs=None, outputs=sdisps).then(
                    vary_btn,gr.Number(label="primary", visible=False), gen_btn).then(
                    vary_btn,gr.Number(label="primary", visible=False), cont_btn)

                cont_btn.click(vary_btn,gr.Number(label="secondary", visible=False), cont_btn).then(
                    fn=self.run_agent, inputs=[gr.Number(False, visible=False),topic_bx,stop_after],
                    outputs=[live]).then(
                    fn=updt_disp, inputs=None, outputs=sdisps).then(
                    vary_btn,gr.Number(label="primary", visible=False), cont_btn)

            with gr.Tab("Plan"):
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                    modify_btn = gr.Button("Modify")
                plan = gr.Textbox(label="Plan", lines=10, interactive=True)
                refresh_btn.click(fn=self.get_state, inputs=gr.Number(label="plan", visible=False), outputs=plan)
                modify_btn.click(fn=self.modify_state, inputs=[gr.Number(label="plan", visible=False),
                                                               gr.Number(label="planner", visible=False), plan],outputs=None).then(
                    fn=updt_disp, inputs=None, outputs=sdisps)
            with gr.Tab("Research Content"):
                refresh_btn = gr.Button("Refresh")
                content_bx = gr.Textbox(label="content", lines=10)
                refresh_btn.click(fn=self.get_content, inputs=None, outputs=content_bx)
            with gr.Tab("Draft"):
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                    modify_btn = gr.Button("Modify")
                draft_bx = gr.Textbox(label="draft", lines=10, interactive=True)
                refresh_btn.click(fn=self.get_state, inputs=gr.Number(label="draft", visible=False), outputs=draft_bx)
                modify_btn.click(fn=self.modify_state, inputs=[gr.Number(label="draft", visible=False),
                                                               gr.Number(label="generate", visible=False),
                                                               draft_bx], outputs=None).then(
                    fn=updt_disp, inputs=None, outputs=sdisps)

            with gr.Tab("Critique"):
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                    modify_btn = gr.Button("Modify")
                critique_bx = gr.Textbox(label="Critique", lines=10, interactive=True)
                refresh_btn.click(fn=self.get_state, inputs=gr.Number(label="critique", visible=False), outputs=critique_bx)
                modify_btn.click(fn=self.modify_state, inputs=[gr.Number(label="critique", visible=False),
                                                               gr.Number(label="reflect", visible=False),
                                                               critique_bx], outputs=None).then(
                    fn=updt_disp, inputs=None, outputs=sdisps)

            with gr.Tab("StateSnapShots"):
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                snapshots = gr.Textbox(label="State Snapshots Summaries")
                refresh_btn.click(fn=get_snapshots, inputs=None, outputs=snapshots)
        return demo

    def launch(self):
        """Start the Gradio server; honor a PORT1 override for hosted deploys."""
        if port := os.getenv("PORT1"):
            self.demo.launch(share=True, server_port=int(port), server_name="0.0.0.0")
        else:
            self.demo.launch(share=self.share)
470
-
471
-
472
if __name__ == "__main__":
    # Build the agent graph and serve the Gradio management UI.
    agent = ewriter()
    gui = writer_gui(agent.graph)
    gui.launch()