Spaces:
Sleeping
Sleeping
Update megumin_agent/chat.py
Browse files- megumin_agent/chat.py +135 -26
megumin_agent/chat.py
CHANGED
|
@@ -2,12 +2,16 @@ from __future__ import annotations
|
|
| 2 |
|
| 3 |
import uuid
|
| 4 |
from dataclasses import dataclass
|
|
|
|
|
|
|
| 5 |
from typing import Iterable
|
| 6 |
|
| 7 |
from .bootstrap import bootstrap_environment
|
| 8 |
|
| 9 |
bootstrap_environment()
|
| 10 |
|
|
|
|
|
|
|
| 11 |
from google.adk.runners import Runner
|
| 12 |
from google.adk.sessions import InMemorySessionService
|
| 13 |
from google.genai import types
|
|
@@ -20,6 +24,9 @@ from .agent import root_agent
|
|
| 20 |
APP_NAME = "megumin_rag_app"
|
| 21 |
MAX_TURNS_IN_CONTEXT = 6
|
| 22 |
SUMMARY_MAX_CHARS = 800
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
@dataclass
|
|
@@ -45,14 +52,78 @@ def _event_texts(events: Iterable) -> list[str]:
|
|
| 45 |
return lines
|
| 46 |
|
| 47 |
|
| 48 |
-
def
|
| 49 |
-
|
| 50 |
-
if
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
|
| 58 |
def _trim_session_history(
|
|
@@ -62,11 +133,7 @@ def _trim_session_history(
|
|
| 62 |
session_id: str,
|
| 63 |
) -> None:
|
| 64 |
session_store = services.session_service.sessions
|
| 65 |
-
storage_session = (
|
| 66 |
-
session_store.get(APP_NAME, {})
|
| 67 |
-
.get(user_id, {})
|
| 68 |
-
.get(session_id)
|
| 69 |
-
)
|
| 70 |
if storage_session is None:
|
| 71 |
return
|
| 72 |
|
|
@@ -76,11 +143,10 @@ def _trim_session_history(
|
|
| 76 |
|
| 77 |
overflow = storage_session.events[:-max_events]
|
| 78 |
storage_session.events = storage_session.events[-max_events:]
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
)
|
| 84 |
|
| 85 |
|
| 86 |
def create_chat_services() -> ChatServices:
|
|
@@ -95,14 +161,24 @@ def create_chat_services() -> ChatServices:
|
|
| 95 |
return ChatServices(runner=runner, session_service=session_service)
|
| 96 |
|
| 97 |
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
user_message: str,
|
| 100 |
services: ChatServices,
|
| 101 |
session_id: str | None = None,
|
| 102 |
user_id: str = "local-user",
|
| 103 |
-
) -> tuple[str, str]:
|
| 104 |
active_session_id = session_id or str(uuid.uuid4())
|
| 105 |
-
last_text = ""
|
| 106 |
existing_session = await services.session_service.get_session(
|
| 107 |
app_name=APP_NAME,
|
| 108 |
user_id=user_id,
|
|
@@ -115,17 +191,34 @@ async def chat_once(
|
|
| 115 |
session_id=active_session_id,
|
| 116 |
)
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
async for event in services.runner.run_async(
|
| 119 |
user_id=user_id,
|
| 120 |
session_id=active_session_id,
|
| 121 |
new_message=types.UserContent(parts=[types.Part(text=user_message)]),
|
|
|
|
| 122 |
):
|
| 123 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
continue
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
_trim_session_history(
|
| 131 |
services,
|
|
@@ -133,4 +226,20 @@ async def chat_once(
|
|
| 133 |
session_id=active_session_id,
|
| 134 |
)
|
| 135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
return last_text, active_session_id
|
|
|
|
| 2 |
|
| 3 |
import uuid
|
| 4 |
from dataclasses import dataclass
|
| 5 |
+
from typing import Any
|
| 6 |
+
from typing import AsyncIterator
|
| 7 |
from typing import Iterable
|
| 8 |
|
| 9 |
from .bootstrap import bootstrap_environment
|
| 10 |
|
| 11 |
bootstrap_environment()
|
| 12 |
|
| 13 |
+
from google.adk.agents.run_config import RunConfig
|
| 14 |
+
from google.adk.agents.run_config import StreamingMode
|
| 15 |
from google.adk.runners import Runner
|
| 16 |
from google.adk.sessions import InMemorySessionService
|
| 17 |
from google.genai import types
|
|
|
|
| 24 |
APP_NAME = "megumin_rag_app"
|
| 25 |
MAX_TURNS_IN_CONTEXT = 6
|
| 26 |
SUMMARY_MAX_CHARS = 800
|
| 27 |
+
SUMMARY_USER_LIMIT = 3
|
| 28 |
+
SUMMARY_ASSISTANT_LIMIT = 2
|
| 29 |
+
SUMMARY_ITEM_CHARS = 42
|
| 30 |
|
| 31 |
|
| 32 |
@dataclass
|
|
|
|
| 52 |
return lines
|
| 53 |
|
| 54 |
|
| 55 |
+
def _compact_summary_item(text: str, limit: int = SUMMARY_ITEM_CHARS) -> str:
|
| 56 |
+
compact = " ".join(str(text or "").split()).strip()
|
| 57 |
+
if len(compact) <= limit:
|
| 58 |
+
return compact
|
| 59 |
+
return compact[: limit - 3].rstrip() + "..."
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _parse_summary_map(value: Any) -> dict[str, list[str]]:
|
| 63 |
+
if not isinstance(value, dict):
|
| 64 |
+
return {
|
| 65 |
+
"user_topics": [],
|
| 66 |
+
"assistant_points": [],
|
| 67 |
+
}
|
| 68 |
+
return {
|
| 69 |
+
"user_topics": [
|
| 70 |
+
str(item) for item in value.get("user_topics", []) if str(item).strip()
|
| 71 |
+
],
|
| 72 |
+
"assistant_points": [
|
| 73 |
+
str(item)
|
| 74 |
+
for item in value.get("assistant_points", [])
|
| 75 |
+
if str(item).strip()
|
| 76 |
+
],
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _merge_unique_tail(previous: list[str], additions: list[str], limit: int) -> list[str]:
|
| 81 |
+
merged: list[str] = []
|
| 82 |
+
for item in [*previous, *additions]:
|
| 83 |
+
if not item or item in merged:
|
| 84 |
+
continue
|
| 85 |
+
merged.append(item)
|
| 86 |
+
return merged[-limit:]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _compress_summary(
|
| 90 |
+
previous_summary_map: Any,
|
| 91 |
+
new_lines: list[str],
|
| 92 |
+
) -> dict[str, list[str]]:
|
| 93 |
+
summary_map = _parse_summary_map(previous_summary_map)
|
| 94 |
+
user_lines = [
|
| 95 |
+
_compact_summary_item(line.removeprefix("user:").strip())
|
| 96 |
+
for line in new_lines
|
| 97 |
+
if line.startswith("user:")
|
| 98 |
+
]
|
| 99 |
+
assistant_lines = [
|
| 100 |
+
_compact_summary_item(line.removeprefix("assistant:").strip())
|
| 101 |
+
for line in new_lines
|
| 102 |
+
if line.startswith("assistant:")
|
| 103 |
+
]
|
| 104 |
+
summary_map["user_topics"] = _merge_unique_tail(
|
| 105 |
+
summary_map["user_topics"],
|
| 106 |
+
user_lines,
|
| 107 |
+
SUMMARY_USER_LIMIT,
|
| 108 |
+
)
|
| 109 |
+
summary_map["assistant_points"] = _merge_unique_tail(
|
| 110 |
+
summary_map["assistant_points"],
|
| 111 |
+
assistant_lines,
|
| 112 |
+
SUMMARY_ASSISTANT_LIMIT,
|
| 113 |
+
)
|
| 114 |
+
return summary_map
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _render_summary(summary_map: dict[str, list[str]]) -> str:
|
| 118 |
+
chunks: list[str] = []
|
| 119 |
+
if summary_map.get("user_topics"):
|
| 120 |
+
chunks.append("user_topics: " + " ; ".join(summary_map["user_topics"]))
|
| 121 |
+
if summary_map.get("assistant_points"):
|
| 122 |
+
chunks.append("assistant_points: " + " ; ".join(summary_map["assistant_points"]))
|
| 123 |
+
rendered = "\n".join(chunks).strip()
|
| 124 |
+
if len(rendered) <= SUMMARY_MAX_CHARS:
|
| 125 |
+
return rendered
|
| 126 |
+
return rendered[: SUMMARY_MAX_CHARS - 3].rstrip() + "..."
|
| 127 |
|
| 128 |
|
| 129 |
def _trim_session_history(
|
|
|
|
| 133 |
session_id: str,
|
| 134 |
) -> None:
|
| 135 |
session_store = services.session_service.sessions
|
| 136 |
+
storage_session = session_store.get(APP_NAME, {}).get(user_id, {}).get(session_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
if storage_session is None:
|
| 138 |
return
|
| 139 |
|
|
|
|
| 143 |
|
| 144 |
overflow = storage_session.events[:-max_events]
|
| 145 |
storage_session.events = storage_session.events[-max_events:]
|
| 146 |
+
previous_summary_map = storage_session.state.get("conversation_summary_map", {})
|
| 147 |
+
summary_map = _compress_summary(previous_summary_map, _event_texts(overflow))
|
| 148 |
+
storage_session.state["conversation_summary_map"] = summary_map
|
| 149 |
+
storage_session.state["conversation_summary"] = _render_summary(summary_map)
|
|
|
|
| 150 |
|
| 151 |
|
| 152 |
def create_chat_services() -> ChatServices:
|
|
|
|
| 161 |
return ChatServices(runner=runner, session_service=session_service)
|
| 162 |
|
| 163 |
|
| 164 |
+
def _extract_text(event: Any) -> str:
|
| 165 |
+
if not getattr(event, "content", None) or not getattr(event.content, "parts", None):
|
| 166 |
+
return ""
|
| 167 |
+
texts = [
|
| 168 |
+
getattr(part, "text", "")
|
| 169 |
+
for part in event.content.parts
|
| 170 |
+
if getattr(part, "text", "")
|
| 171 |
+
]
|
| 172 |
+
return "".join(texts).strip()
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
async def stream_chat(
|
| 176 |
user_message: str,
|
| 177 |
services: ChatServices,
|
| 178 |
session_id: str | None = None,
|
| 179 |
user_id: str = "local-user",
|
| 180 |
+
) -> AsyncIterator[tuple[str, str]]:
|
| 181 |
active_session_id = session_id or str(uuid.uuid4())
|
|
|
|
| 182 |
existing_session = await services.session_service.get_session(
|
| 183 |
app_name=APP_NAME,
|
| 184 |
user_id=user_id,
|
|
|
|
| 191 |
session_id=active_session_id,
|
| 192 |
)
|
| 193 |
|
| 194 |
+
streamed_text = ""
|
| 195 |
+
final_text = ""
|
| 196 |
+
run_config = RunConfig(streaming_mode=StreamingMode.SSE)
|
| 197 |
+
|
| 198 |
async for event in services.runner.run_async(
|
| 199 |
user_id=user_id,
|
| 200 |
session_id=active_session_id,
|
| 201 |
new_message=types.UserContent(parts=[types.Part(text=user_message)]),
|
| 202 |
+
run_config=run_config,
|
| 203 |
):
|
| 204 |
+
if getattr(event, "author", None) == "user":
|
| 205 |
+
continue
|
| 206 |
+
|
| 207 |
+
text = _extract_text(event)
|
| 208 |
+
if not text:
|
| 209 |
continue
|
| 210 |
+
|
| 211 |
+
if getattr(event, "partial", None) is True:
|
| 212 |
+
streamed_text += text
|
| 213 |
+
yield streamed_text, active_session_id
|
| 214 |
+
continue
|
| 215 |
+
|
| 216 |
+
if getattr(event, "is_final_response", None) and event.is_final_response():
|
| 217 |
+
final_text = text
|
| 218 |
+
|
| 219 |
+
if final_text and final_text != streamed_text:
|
| 220 |
+
streamed_text = final_text
|
| 221 |
+
yield streamed_text, active_session_id
|
| 222 |
|
| 223 |
_trim_session_history(
|
| 224 |
services,
|
|
|
|
| 226 |
session_id=active_session_id,
|
| 227 |
)
|
| 228 |
|
| 229 |
+
|
| 230 |
+
async def chat_once(
|
| 231 |
+
user_message: str,
|
| 232 |
+
services: ChatServices,
|
| 233 |
+
session_id: str | None = None,
|
| 234 |
+
user_id: str = "local-user",
|
| 235 |
+
) -> tuple[str, str]:
|
| 236 |
+
last_text = ""
|
| 237 |
+
active_session_id = session_id or str(uuid.uuid4())
|
| 238 |
+
async for chunk_text, active_session_id in stream_chat(
|
| 239 |
+
user_message=user_message,
|
| 240 |
+
services=services,
|
| 241 |
+
session_id=active_session_id,
|
| 242 |
+
user_id=user_id,
|
| 243 |
+
):
|
| 244 |
+
last_text = chunk_text
|
| 245 |
return last_text, active_session_id
|