Junhoee commited on
Commit
c8b5814
·
verified ·
1 Parent(s): 42a5f70

Update megumin_agent/chat.py

Browse files
Files changed (1) hide show
  1. megumin_agent/chat.py +135 -26
megumin_agent/chat.py CHANGED
@@ -2,12 +2,16 @@ from __future__ import annotations
2
 
3
  import uuid
4
  from dataclasses import dataclass
 
 
5
  from typing import Iterable
6
 
7
  from .bootstrap import bootstrap_environment
8
 
9
  bootstrap_environment()
10
 
 
 
11
  from google.adk.runners import Runner
12
  from google.adk.sessions import InMemorySessionService
13
  from google.genai import types
@@ -20,6 +24,9 @@ from .agent import root_agent
20
  APP_NAME = "megumin_rag_app"
21
  MAX_TURNS_IN_CONTEXT = 6
22
  SUMMARY_MAX_CHARS = 800
 
 
 
23
 
24
 
25
  @dataclass
@@ -45,14 +52,78 @@ def _event_texts(events: Iterable) -> list[str]:
45
  return lines
46
 
47
 
48
- def _compress_summary(previous_summary: str, new_lines: list[str]) -> str:
49
- pieces = [previous_summary.strip()] if previous_summary.strip() else []
50
- if new_lines:
51
- pieces.append(" / ".join(new_lines))
52
- summary = " | ".join(piece for piece in pieces if piece).strip()
53
- if len(summary) <= SUMMARY_MAX_CHARS:
54
- return summary
55
- return "..." + summary[-(SUMMARY_MAX_CHARS - 3) :]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
 
58
  def _trim_session_history(
@@ -62,11 +133,7 @@ def _trim_session_history(
62
  session_id: str,
63
  ) -> None:
64
  session_store = services.session_service.sessions
65
- storage_session = (
66
- session_store.get(APP_NAME, {})
67
- .get(user_id, {})
68
- .get(session_id)
69
- )
70
  if storage_session is None:
71
  return
72
 
@@ -76,11 +143,10 @@ def _trim_session_history(
76
 
77
  overflow = storage_session.events[:-max_events]
78
  storage_session.events = storage_session.events[-max_events:]
79
- previous_summary = str(storage_session.state.get("conversation_summary", ""))
80
- storage_session.state["conversation_summary"] = _compress_summary(
81
- previous_summary,
82
- _event_texts(overflow),
83
- )
84
 
85
 
86
  def create_chat_services() -> ChatServices:
@@ -95,14 +161,24 @@ def create_chat_services() -> ChatServices:
95
  return ChatServices(runner=runner, session_service=session_service)
96
 
97
 
98
- async def chat_once(
 
 
 
 
 
 
 
 
 
 
 
99
  user_message: str,
100
  services: ChatServices,
101
  session_id: str | None = None,
102
  user_id: str = "local-user",
103
- ) -> tuple[str, str]:
104
  active_session_id = session_id or str(uuid.uuid4())
105
- last_text = ""
106
  existing_session = await services.session_service.get_session(
107
  app_name=APP_NAME,
108
  user_id=user_id,
@@ -115,17 +191,34 @@ async def chat_once(
115
  session_id=active_session_id,
116
  )
117
 
 
 
 
 
118
  async for event in services.runner.run_async(
119
  user_id=user_id,
120
  session_id=active_session_id,
121
  new_message=types.UserContent(parts=[types.Part(text=user_message)]),
 
122
  ):
123
- if not event.content or not event.content.parts:
 
 
 
 
124
  continue
125
- for part in event.content.parts:
126
- text = getattr(part, "text", None)
127
- if text and event.author != "user":
128
- last_text = text
 
 
 
 
 
 
 
 
129
 
130
  _trim_session_history(
131
  services,
@@ -133,4 +226,20 @@ async def chat_once(
133
  session_id=active_session_id,
134
  )
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  return last_text, active_session_id
 
2
 
3
  import uuid
4
  from dataclasses import dataclass
5
+ from typing import Any
6
+ from typing import AsyncIterator
7
  from typing import Iterable
8
 
9
  from .bootstrap import bootstrap_environment
10
 
11
  bootstrap_environment()
12
 
13
+ from google.adk.agents.run_config import RunConfig
14
+ from google.adk.agents.run_config import StreamingMode
15
  from google.adk.runners import Runner
16
  from google.adk.sessions import InMemorySessionService
17
  from google.genai import types
 
24
  APP_NAME = "megumin_rag_app"
25
  MAX_TURNS_IN_CONTEXT = 6
26
  SUMMARY_MAX_CHARS = 800
27
+ SUMMARY_USER_LIMIT = 3
28
+ SUMMARY_ASSISTANT_LIMIT = 2
29
+ SUMMARY_ITEM_CHARS = 42
30
 
31
 
32
  @dataclass
 
52
  return lines
53
 
54
 
55
+ def _compact_summary_item(text: str, limit: int = SUMMARY_ITEM_CHARS) -> str:
56
+ compact = " ".join(str(text or "").split()).strip()
57
+ if len(compact) <= limit:
58
+ return compact
59
+ return compact[: limit - 3].rstrip() + "..."
60
+
61
+
62
+ def _parse_summary_map(value: Any) -> dict[str, list[str]]:
63
+ if not isinstance(value, dict):
64
+ return {
65
+ "user_topics": [],
66
+ "assistant_points": [],
67
+ }
68
+ return {
69
+ "user_topics": [
70
+ str(item) for item in value.get("user_topics", []) if str(item).strip()
71
+ ],
72
+ "assistant_points": [
73
+ str(item)
74
+ for item in value.get("assistant_points", [])
75
+ if str(item).strip()
76
+ ],
77
+ }
78
+
79
+
80
+ def _merge_unique_tail(previous: list[str], additions: list[str], limit: int) -> list[str]:
81
+ merged: list[str] = []
82
+ for item in [*previous, *additions]:
83
+ if not item or item in merged:
84
+ continue
85
+ merged.append(item)
86
+ return merged[-limit:]
87
+
88
+
89
+ def _compress_summary(
90
+ previous_summary_map: Any,
91
+ new_lines: list[str],
92
+ ) -> dict[str, list[str]]:
93
+ summary_map = _parse_summary_map(previous_summary_map)
94
+ user_lines = [
95
+ _compact_summary_item(line.removeprefix("user:").strip())
96
+ for line in new_lines
97
+ if line.startswith("user:")
98
+ ]
99
+ assistant_lines = [
100
+ _compact_summary_item(line.removeprefix("assistant:").strip())
101
+ for line in new_lines
102
+ if line.startswith("assistant:")
103
+ ]
104
+ summary_map["user_topics"] = _merge_unique_tail(
105
+ summary_map["user_topics"],
106
+ user_lines,
107
+ SUMMARY_USER_LIMIT,
108
+ )
109
+ summary_map["assistant_points"] = _merge_unique_tail(
110
+ summary_map["assistant_points"],
111
+ assistant_lines,
112
+ SUMMARY_ASSISTANT_LIMIT,
113
+ )
114
+ return summary_map
115
+
116
+
117
+ def _render_summary(summary_map: dict[str, list[str]]) -> str:
118
+ chunks: list[str] = []
119
+ if summary_map.get("user_topics"):
120
+ chunks.append("user_topics: " + " ; ".join(summary_map["user_topics"]))
121
+ if summary_map.get("assistant_points"):
122
+ chunks.append("assistant_points: " + " ; ".join(summary_map["assistant_points"]))
123
+ rendered = "\n".join(chunks).strip()
124
+ if len(rendered) <= SUMMARY_MAX_CHARS:
125
+ return rendered
126
+ return rendered[: SUMMARY_MAX_CHARS - 3].rstrip() + "..."
127
 
128
 
129
  def _trim_session_history(
 
133
  session_id: str,
134
  ) -> None:
135
  session_store = services.session_service.sessions
136
+ storage_session = session_store.get(APP_NAME, {}).get(user_id, {}).get(session_id)
 
 
 
 
137
  if storage_session is None:
138
  return
139
 
 
143
 
144
  overflow = storage_session.events[:-max_events]
145
  storage_session.events = storage_session.events[-max_events:]
146
+ previous_summary_map = storage_session.state.get("conversation_summary_map", {})
147
+ summary_map = _compress_summary(previous_summary_map, _event_texts(overflow))
148
+ storage_session.state["conversation_summary_map"] = summary_map
149
+ storage_session.state["conversation_summary"] = _render_summary(summary_map)
 
150
 
151
 
152
  def create_chat_services() -> ChatServices:
 
161
  return ChatServices(runner=runner, session_service=session_service)
162
 
163
 
164
+ def _extract_text(event: Any) -> str:
165
+ if not getattr(event, "content", None) or not getattr(event.content, "parts", None):
166
+ return ""
167
+ texts = [
168
+ getattr(part, "text", "")
169
+ for part in event.content.parts
170
+ if getattr(part, "text", "")
171
+ ]
172
+ return "".join(texts).strip()
173
+
174
+
175
+ async def stream_chat(
176
  user_message: str,
177
  services: ChatServices,
178
  session_id: str | None = None,
179
  user_id: str = "local-user",
180
+ ) -> AsyncIterator[tuple[str, str]]:
181
  active_session_id = session_id or str(uuid.uuid4())
 
182
  existing_session = await services.session_service.get_session(
183
  app_name=APP_NAME,
184
  user_id=user_id,
 
191
  session_id=active_session_id,
192
  )
193
 
194
+ streamed_text = ""
195
+ final_text = ""
196
+ run_config = RunConfig(streaming_mode=StreamingMode.SSE)
197
+
198
  async for event in services.runner.run_async(
199
  user_id=user_id,
200
  session_id=active_session_id,
201
  new_message=types.UserContent(parts=[types.Part(text=user_message)]),
202
+ run_config=run_config,
203
  ):
204
+ if getattr(event, "author", None) == "user":
205
+ continue
206
+
207
+ text = _extract_text(event)
208
+ if not text:
209
  continue
210
+
211
+ if getattr(event, "partial", None) is True:
212
+ streamed_text += text
213
+ yield streamed_text, active_session_id
214
+ continue
215
+
216
+ if getattr(event, "is_final_response", None) and event.is_final_response():
217
+ final_text = text
218
+
219
+ if final_text and final_text != streamed_text:
220
+ streamed_text = final_text
221
+ yield streamed_text, active_session_id
222
 
223
  _trim_session_history(
224
  services,
 
226
  session_id=active_session_id,
227
  )
228
 
229
+
230
+ async def chat_once(
231
+ user_message: str,
232
+ services: ChatServices,
233
+ session_id: str | None = None,
234
+ user_id: str = "local-user",
235
+ ) -> tuple[str, str]:
236
+ last_text = ""
237
+ active_session_id = session_id or str(uuid.uuid4())
238
+ async for chunk_text, active_session_id in stream_chat(
239
+ user_message=user_message,
240
+ services=services,
241
+ session_id=active_session_id,
242
+ user_id=user_id,
243
+ ):
244
+ last_text = chunk_text
245
  return last_text, active_session_id