frdel commited on
Commit
5b177ea
·
1 Parent(s): 5ec935f

Persist chat prototype

Browse files
agent.py CHANGED
@@ -23,19 +23,22 @@ class AgentContext:
23
  _counter: int = 0
24
 
25
  def __init__(
26
- self, config: "AgentConfig", id: str | None = None, agent0: "Agent|None" = None
 
 
27
  ):
28
  # build context
29
  self.id = id or str(uuid.uuid4())
 
30
  self.config = config
31
- self.log = Log.Log()
32
  self.agent0 = agent0 or Agent(0, self.config, self)
33
- self.paused = False
34
- self.streaming_agent: Agent | None = None
35
  self.process: DeferredTask | None = None
36
  AgentContext._counter += 1
37
  self.no = AgentContext._counter
38
-
39
  self._contexts[self.id] = self
40
 
41
  @staticmethod
@@ -317,9 +320,6 @@ class Agent:
317
  agent_response
318
  ) # process tools requested in agent message
319
  if tools_result: # final response of message loop available
320
- await self.call_extensions(
321
- "monologue_end", tools_result=tools_result
322
- ) # call monologue_end extensions
323
  return tools_result # break the execution if the task is done
324
 
325
  # exceptions inside message loop:
@@ -338,6 +338,12 @@ class Agent:
338
  except Exception as e: # Other exception kill the loop
339
  self.handle_critical_exception(e)
340
 
 
 
 
 
 
 
341
  # exceptions outside message loop:
342
  except InterventionException as e:
343
  pass # just start over
@@ -345,6 +351,8 @@ class Agent:
345
  self.handle_critical_exception(e)
346
  finally:
347
  self.context.streaming_agent = None # unset current streamer
 
 
348
 
349
  def handle_critical_exception(self, exception: Exception):
350
  if isinstance(exception, HandledException):
 
23
  _counter: int = 0
24
 
25
  def __init__(
26
+ self, config: "AgentConfig", id: str | None = None, name: str | None = None, agent0: "Agent|None" = None,
27
+ log: Log.Log | None = None,
28
+ paused: bool = False, streaming_agent: "Agent|None" = None,
29
  ):
30
  # build context
31
  self.id = id or str(uuid.uuid4())
32
+ self.name = name
33
  self.config = config
34
+ self.log = log or Log.Log()
35
  self.agent0 = agent0 or Agent(0, self.config, self)
36
+ self.paused = paused
37
+ self.streaming_agent = streaming_agent
38
  self.process: DeferredTask | None = None
39
  AgentContext._counter += 1
40
  self.no = AgentContext._counter
41
+
42
  self._contexts[self.id] = self
43
 
44
  @staticmethod
 
320
  agent_response
321
  ) # process tools requested in agent message
322
  if tools_result: # final response of message loop available
 
 
 
323
  return tools_result # break the execution if the task is done
324
 
325
  # exceptions inside message loop:
 
338
  except Exception as e: # Other exception kill the loop
339
  self.handle_critical_exception(e)
340
 
341
+ finally:
342
+ # call message_loop_end extensions
343
+ await self.call_extensions(
344
+ "message_loop_end", loop_data=loop_data
345
+ )
346
+
347
  # exceptions outside message loop:
348
  except InterventionException as e:
349
  pass # just start over
 
351
  self.handle_critical_exception(e)
352
  finally:
353
  self.context.streaming_agent = None # unset current streamer
354
+ # call monologue_end extensions
355
+ await self.call_extensions("monologue_end", loop_data=loop_data) # type: ignore
356
 
357
  def handle_critical_exception(self, exception: Exception):
358
  if isinstance(exception, HandledException):
python/extensions/{msg_loop_break → message_loop_end}/.gitkeep RENAMED
File without changes
python/extensions/message_loop_end/_90_save_chat.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from python.helpers.extension import Extension
2
+ from agent import LoopData
3
+ from python.helpers import persist_chat
4
+
5
+
6
+ class SaveChat(Extension):
7
+ async def execute(self, loop_data: LoopData = LoopData(), **kwargs):
8
+ persist_chat.save_chat(self.agent.context)
python/extensions/msg_loop_end/.gitkeep DELETED
File without changes
python/helpers/files.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os, re
2
 
3
  import re
@@ -62,6 +63,22 @@ def find_file_in_dirs(file_path, backup_dirs):
62
  def remove_code_fences(text):
63
  return re.sub(r'~~~\w*\n|~~~', '', text)
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  def get_abs_path(*relative_paths):
67
  return os.path.join(get_base_dir(), *relative_paths)
@@ -70,7 +87,6 @@ def exists(*relative_paths):
70
  path = get_abs_path(*relative_paths)
71
  return os.path.exists(path)
72
 
73
-
74
  def get_base_dir():
75
  # Get the base directory from the current file path
76
  base_dir = os.path.dirname(os.path.abspath(os.path.join(__file__,"../../")))
 
1
+ from fnmatch import fnmatch
2
  import os, re
3
 
4
  import re
 
63
  def remove_code_fences(text):
64
  return re.sub(r'~~~\w*\n|~~~', '', text)
65
 
66
+ def write_file(relative_path:str, content:str):
67
+ abs_path = get_abs_path(relative_path)
68
+ os.makedirs(os.path.dirname(abs_path), exist_ok=True)
69
+ with open(abs_path, 'w') as f:
70
+ f.write(content)
71
+
72
+ def delete_file(relative_path:str):
73
+ abs_path = get_abs_path(relative_path)
74
+ if os.path.exists(abs_path):
75
+ os.remove(abs_path)
76
+
77
+ def list_files(relative_path:str, filter:str="*"):
78
+ abs_path = get_abs_path(relative_path)
79
+ if not os.path.exists(abs_path):
80
+ return []
81
+ return [file for file in os.listdir(abs_path) if fnmatch(file, filter)]
82
 
83
  def get_abs_path(*relative_paths):
84
  return os.path.join(get_base_dir(), *relative_paths)
 
87
  path = get_abs_path(*relative_paths)
88
  return os.path.exists(path)
89
 
 
90
  def get_base_dir():
91
  # Get the base directory from the current file path
92
  base_dir = os.path.dirname(os.path.abspath(os.path.join(__file__,"../../")))
python/helpers/log.py CHANGED
@@ -1,6 +1,6 @@
1
  from dataclasses import dataclass, field
2
  import json
3
- from typing import Literal, Optional, Dict
4
  import uuid
5
  from collections import OrderedDict # Import OrderedDict
6
 
@@ -145,7 +145,7 @@ class Log:
145
 
146
  self.updates += [item.no]
147
 
148
- def output(self, start=None, end=None):
149
  if start is None:
150
  start = 0
151
  if end is None:
 
1
  from dataclasses import dataclass, field
2
  import json
3
+ from typing import Any, Literal, Optional, Dict
4
  import uuid
5
  from collections import OrderedDict # Import OrderedDict
6
 
 
145
 
146
  self.updates += [item.no]
147
 
148
+ def output(self, start=None, end=None):
149
  if start is None:
150
  start = 0
151
  if end is None:
python/helpers/persist_chat.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Any
3
+ import uuid
4
+ from agent import Agent, AgentConfig, AgentContext, HumanMessage, AIMessage
5
+ from python.helpers import files
6
+ import json
7
+ from initialize import initialize
8
+
9
+ from python.helpers.log import Log, LogItem
10
+
11
+ CHATS_FOLDER = "tmp/chats"
12
+ LOG_SIZE = 1000
13
+
14
+
15
+ def save_chat(context: AgentContext):
16
+ relative_path = _get_file_path(context.id)
17
+ data = _serialize_context(context)
18
+ js = _safe_json_serialize(data, ensure_ascii=False)
19
+ files.write_file(relative_path, js)
20
+
21
+
22
+ def load_chats():
23
+ json_files = files.list_files("tmp/chats", "*.json")
24
+ for file in json_files:
25
+ path = files.get_abs_path(CHATS_FOLDER, file)
26
+ js = files.read_file(path)
27
+ data = json.loads(js)
28
+ ctx = _deserialize_context(data)
29
+
30
+
31
+ def remove_chat(ctxid):
32
+ files.delete_file(_get_file_path(ctxid))
33
+
34
+
35
+ def _get_file_path(ctxid: str):
36
+ return f"{CHATS_FOLDER}/{ctxid}.json"
37
+
38
+
39
+ def _serialize_context(context: AgentContext):
40
+ # serialize agents
41
+ agents = []
42
+ agent = context.agent0
43
+ while agent:
44
+ agents.append(_serialize_agent(agent))
45
+ agent = agent.data.get("subordinate", None)
46
+
47
+ return {
48
+ "id": context.id,
49
+ "agents": agents,
50
+ "streaming_agent": (
51
+ context.streaming_agent.number if context.streaming_agent else 0
52
+ ),
53
+ "log": _serialize_log(context.log),
54
+ }
55
+
56
+
57
+ def _serialize_agent(agent: Agent):
58
+ data = {**agent.data}
59
+ if "superior" in data:
60
+ del data["superior"]
61
+ if "subordinate" in data:
62
+ del data["subordinate"]
63
+
64
+ history = []
65
+ for msg in agent.history:
66
+ history.append({"type": msg.type, "content": msg.content})
67
+
68
+ return {
69
+ "number": agent.number,
70
+ "data": data,
71
+ "history": history,
72
+ }
73
+
74
+
75
+ def _serialize_log(log: Log):
76
+ return {
77
+ "guid": log.guid,
78
+ "logs": [item.output() for item in log.logs[-LOG_SIZE:]]
79
+ , # serialize LogItem objects
80
+ "progress": log.progress,
81
+ "progress_no": log.progress_no,
82
+ }
83
+
84
+
85
+ def _deserialize_context(data):
86
+ config = initialize()
87
+ log = _deserialize_log(data.get("log", None))
88
+
89
+ context = AgentContext(
90
+ config=config,
91
+ id=data.get("id", None),
92
+ name=data.get("name", None),
93
+ log=log,
94
+ paused=True,
95
+ # agent0=agent0,
96
+ # streaming_agent=straming_agent,
97
+ )
98
+
99
+ agents = data.get("agents", [])
100
+ agent0 = _deserialize_agents(agents, config, context)
101
+ streaming_agent_no = data.get("streaming_agent", 0)
102
+ straming_agent = (
103
+ agents[streaming_agent_no] if streaming_agent_no < len(agents) else None
104
+ )
105
+
106
+ context.agent0 = agent0
107
+ context.streaming_agent = straming_agent
108
+
109
+ return context
110
+
111
+
112
+ def _deserialize_agents(
113
+ agents: list[dict[str, Any]], config: AgentConfig, context: AgentContext
114
+ ) -> Agent:
115
+ prev: Agent | None = None
116
+ zero: Agent | None = None
117
+
118
+ for ag in agents:
119
+ current = Agent(
120
+ number=ag["number"],
121
+ config=config,
122
+ context=context,
123
+ )
124
+ current.data = ag.get("data", {})
125
+ current.history = _deserialize_history(ag.get("history", []))
126
+
127
+ if not zero:
128
+ zero = current
129
+
130
+ if prev:
131
+ prev.set_data("subordinate", current)
132
+ current.set_data("superior", prev)
133
+
134
+ return zero or Agent(0, config, context)
135
+
136
+
137
+ def _deserialize_history(history: list[dict[str, Any]]):
138
+ result = []
139
+ for hist in history:
140
+ content = hist.get("content", "")
141
+ msg = (
142
+ HumanMessage(content=content)
143
+ if hist.get("type") == "human"
144
+ else AIMessage(content=content)
145
+ )
146
+ result.append(msg)
147
+ return result
148
+
149
+
150
+ def _deserialize_log(data: dict[str, Any]) -> "Log":
151
+ log = Log()
152
+ log.guid = data.get("guid", str(uuid.uuid4()))
153
+ log.progress = data.get("progress", "")
154
+ log.progress_no = data.get("progress_no", 0)
155
+
156
+ # Deserialize the list of LogItem objects
157
+ i = 0
158
+ for item_data in data.get("logs", []):
159
+ log.logs.append(LogItem(
160
+ log=log, # restore the log reference
161
+ no=item_data["no"],
162
+ type=item_data["type"],
163
+ heading=item_data.get("heading", ""),
164
+ content=item_data.get("content", ""),
165
+ kvps=OrderedDict(item_data["kvps"]) if item_data["kvps"] else None,
166
+ temp=item_data.get("temp", False),
167
+ ))
168
+ log.updates.append(i)
169
+ i += 1
170
+
171
+ return log
172
+
173
+
174
+ def _safe_json_serialize(obj, **kwargs):
175
+ def serializer(o):
176
+ if isinstance(o, dict):
177
+ return {k: v for k, v in o.items() if is_json_serializable(v)}
178
+ elif isinstance(o, (list, tuple)):
179
+ return [item for item in o if is_json_serializable(item)]
180
+ elif is_json_serializable(o):
181
+ return o
182
+ else:
183
+ return None # Skip this property
184
+
185
+ def is_json_serializable(item):
186
+ try:
187
+ json.dumps(item)
188
+ return True
189
+ except (TypeError, OverflowError):
190
+ return False
191
+
192
+ return json.dumps(obj, default=serializer, **kwargs)
run_ui.py CHANGED
@@ -8,9 +8,11 @@ from flask import Flask, request, jsonify, Response
8
  from flask_basicauth import BasicAuth
9
  from agent import AgentContext
10
  from initialize import initialize
 
11
  from python.helpers.files import get_abs_path
12
  from python.helpers.print_style import PrintStyle
13
  from python.helpers.dotenv import load_dotenv
 
14
 
15
 
16
  # initialize the internal Flask server
@@ -119,6 +121,7 @@ async def handle_message(sync: bool):
119
  response = {
120
  "ok": True,
121
  "message": result,
 
122
  }
123
  else:
124
 
@@ -126,6 +129,7 @@ async def handle_message(sync: bool):
126
  response = {
127
  "ok": True,
128
  "message": "Message received.",
 
129
  }
130
 
131
  except Exception as e:
@@ -183,6 +187,7 @@ async def reset():
183
  # context instance - get or create
184
  context = get_context(ctxid)
185
  context.reset()
 
186
 
187
  response = {
188
  "ok": True,
@@ -211,6 +216,7 @@ async def remove():
211
 
212
  # context instance - get or create
213
  AgentContext.remove(ctxid)
 
214
 
215
  response = {
216
  "ok": True,
@@ -235,7 +241,7 @@ async def poll():
235
 
236
  # data sent to the server
237
  input = request.get_json()
238
- ctxid = input.get("context", uuid.uuid4())
239
  from_no = input.get("log_from", 0)
240
 
241
  # context instance - get or create
@@ -286,6 +292,9 @@ def run():
286
 
287
  #load env vars
288
  load_dotenv()
 
 
 
289
 
290
  # Suppress only request logs but keep the startup messages
291
  from werkzeug.serving import WSGIRequestHandler
 
8
  from flask_basicauth import BasicAuth
9
  from agent import AgentContext
10
  from initialize import initialize
11
+ from python.helpers import files
12
  from python.helpers.files import get_abs_path
13
  from python.helpers.print_style import PrintStyle
14
  from python.helpers.dotenv import load_dotenv
15
+ from python.helpers import persist_chat
16
 
17
 
18
  # initialize the internal Flask server
 
121
  response = {
122
  "ok": True,
123
  "message": result,
124
+ "context": context.id,
125
  }
126
  else:
127
 
 
129
  response = {
130
  "ok": True,
131
  "message": "Message received.",
132
+ "context": context.id,
133
  }
134
 
135
  except Exception as e:
 
187
  # context instance - get or create
188
  context = get_context(ctxid)
189
  context.reset()
190
+ persist_chat.save_chat(context)
191
 
192
  response = {
193
  "ok": True,
 
216
 
217
  # context instance - get or create
218
  AgentContext.remove(ctxid)
219
+ persist_chat.remove_chat(ctxid)
220
 
221
  response = {
222
  "ok": True,
 
241
 
242
  # data sent to the server
243
  input = request.get_json()
244
+ ctxid = input.get("context", None)
245
  from_no = input.get("log_from", 0)
246
 
247
  # context instance - get or create
 
292
 
293
  #load env vars
294
  load_dotenv()
295
+
296
+ # initialize contexts from persisted chats
297
+ persist_chat.load_chats()
298
 
299
  # Suppress only request logs but keep the startup messages
300
  from werkzeug.serving import WSGIRequestHandler
webui/index.js CHANGED
@@ -73,6 +73,8 @@ async function sendMessage() {
73
  } else {
74
  toast("Undefined error.", "error")
75
  }
 
 
76
  }
77
 
78
  //setMessage('user', message);
@@ -186,7 +188,8 @@ async function poll() {
186
 
187
  if (response.ok) {
188
 
189
- setContext(response.context)
 
190
 
191
  if (lastLogGuid != response.log_guid) {
192
  chatHistory.innerHTML = ""
@@ -408,7 +411,6 @@ function scrollChanged(isAtBottom) {
408
  const inputAS = Alpine.$data(autoScrollSwitch);
409
  inputAS.autoScroll = isAtBottom
410
  // autoScrollSwitch.checked = isAtBottom
411
- console.log(isAtBottom)
412
  }
413
 
414
  function updateAfterScroll() {
 
73
  } else {
74
  toast("Undefined error.", "error")
75
  }
76
+ } else {
77
+ setContext(response.context)
78
  }
79
 
80
  //setMessage('user', message);
 
188
 
189
  if (response.ok) {
190
 
191
+ if (!context) setContext(response.context)
192
+ if (response.context != context) return //skip late polls after context change
193
 
194
  if (lastLogGuid != response.log_guid) {
195
  chatHistory.innerHTML = ""
 
411
  const inputAS = Alpine.$data(autoScrollSwitch);
412
  inputAS.autoScroll = isAtBottom
413
  // autoScrollSwitch.checked = isAtBottom
 
414
  }
415
 
416
  function updateAfterScroll() {