# chatbot-rag-fi/src/agents/services.py
# Uploaded by ABAO77 - "Upload 147 files" (commit 0df80b4, verified)
from __future__ import annotations
import asyncio
import logging
from collections.abc import AsyncGenerator
from time import perf_counter
from typing import Any
import json
import uuid
from agents import Agent, ModelSettings, Runner, RunState
from agents.items import ToolCallItem, ToolCallOutputItem
from openai.types.responses import ResponseTextDeltaEvent
from src.agents.state import AgentContext, AgentRunResult
from src.agents.flow import run_guardrail
from src.utils.message_builder import MessageBuilder
from src.agents.prompts import get_prompt_bundle
from src.utils.tool_event_inspector import ToolEventInspector
from src.agents.tools import hand_off_ceo, retrieve_brand_context
from src.utils.agent_utils import (
insufficiency_fallback,
input_guardrail_fallback,
system_error_fallback,
)
from src.schemas import ChatHistoryMessage, ChatTextSegment
from src.services.citations import CitationTagStreamFilter, parse_citation_segments
from src.services.llm import get_chat_model
logger = logging.getLogger(__name__)
class AgentService:
    """Run the brand assistant agent and expose its output as an event stream.

    The service lazily builds one ``Agent`` instance, prepares the per-request
    context and input messages, can restore an interrupted run from its
    serialized state, and streams a run while an input guardrail is evaluated
    concurrently.
    """

    # Tuples rather than sets: ``in`` on a tuple compares with ``==`` and so
    # never hashes the probed value. Persisted JSON may hold a list or dict
    # under "role"/"type", which would make a set lookup raise TypeError.
    _FLATTENABLE_ROLES = ("assistant", "user", "system")
    _TEXT_PART_TYPES = ("text", "output_text")

    # Turn limit passed to the runner for fresh (non-resumed) runs.
    _MAX_TURNS = 6

    def __init__(self) -> None:
        """Create the service; the agent itself is built on first access."""
        self._assistant_agent: Agent[AgentContext] | None = None

    @property
    def assistant_agent(self) -> Agent[AgentContext]:
        """Return the shared assistant agent, constructing it on first use."""
        if self._assistant_agent is None:
            bundle = get_prompt_bundle()
            self._assistant_agent = Agent(
                name="brand-assistant",
                instructions=bundle.system_prompt,
                model=get_chat_model(),
                model_settings=ModelSettings(
                    parallel_tool_calls=False,
                ),
                tools=[retrieve_brand_context, hand_off_ceo],
            )
        return self._assistant_agent

    def build_context(self, question: str, history: list[ChatHistoryMessage]) -> AgentContext:
        """Build the run context for a fresh question.

        ``message_count`` is the history length plus one for the new question.
        """
        return AgentContext(
            question=question,
            message_count=len(history) + 1,
            prompt_bundle=get_prompt_bundle(),
        )

    def build_messages(self, history: list[ChatHistoryMessage], question: str) -> list[dict[str, str]]:
        """Convert the chat history and the new question into agent input items."""
        return MessageBuilder.build_input_items(history, question)

    @staticmethod
    def _elapsed_ms(started_at: float) -> int:
        """Return whole milliseconds elapsed since *started_at*, never below 1."""
        return max(1, round((perf_counter() - started_at) * 1000))

    @staticmethod
    def _json_default(value: Any) -> Any:
        """``json.dumps`` fallback: encode sets as lists, reject anything else.

        Raises:
            TypeError: for any other non-serializable object. Returning the
                object unchanged instead would make ``json`` report a
                misleading "Circular reference detected" error.
        """
        if isinstance(value, (set, frozenset)):
            return list(value)
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    @staticmethod
    def _normalize_resume_state_payload(node: Any) -> None:
        """Normalize persisted message parts for SDK resume compatibility.

        Walks *node* (any JSON-like structure) recursively and mutates it in
        place:

        * a dict with role assistant/user/system whose ``content`` is a list
          has that list replaced by the concatenation of its text parts, and
          its ``status`` key removed;
        * a dict of type ``output_text`` that has a ``text`` key is retyped
          to ``text``.

        Values that are neither dicts nor lists are left untouched.
        """
        if isinstance(node, list):
            for value in node:
                AgentService._normalize_resume_state_payload(value)
            return
        if not isinstance(node, dict):
            return

        # Some persisted chat-completions items include a message envelope like:
        # {"role":"assistant","content":[{"type":"text","text":"..."}],"status":"completed"}
        # Normalize it into the simpler structure expected by the converter.
        role = node.get("role")
        content = node.get("content")
        if role in AgentService._FLATTENABLE_ROLES and isinstance(content, list):
            flattened_parts = [
                part["text"]
                for part in content
                if isinstance(part, dict)
                and part.get("type") in AgentService._TEXT_PART_TYPES
                and isinstance(part.get("text"), str)
            ]
            # NOTE: once at least one text part exists, non-text parts of the
            # same message are dropped by this replacement.
            if flattened_parts:
                node["content"] = "".join(flattened_parts)
                node.pop("status", None)

        if node.get("type") == "output_text" and "text" in node:
            node["type"] = "text"

        for value in node.values():
            AgentService._normalize_resume_state_payload(value)

    async def _load_resume_state(self, state_json: str) -> RunState:
        """Rebuild a ``RunState`` from its JSON string form.

        The payload is normalized first, then handed to ``RunState.from_json``
        together with a deserializer that recreates ``AgentContext``.
        ``citation_ids`` is converted back to a set because sets are written
        out as lists when the state is serialized.
        """
        state_dict = json.loads(state_json)
        self._normalize_resume_state_payload(state_dict)
        state = await RunState.from_json(
            self.assistant_agent,
            state_dict,
            context_deserializer=lambda x: AgentContext(**x),
        )
        restored_context = state._context.context
        if isinstance(restored_context.citation_ids, list):
            restored_context.citation_ids = set(restored_context.citation_ids)
        return state

    def _build_result(
        self,
        *,
        question: str,
        context: AgentContext,
        raw_output: str,
        ttft_ms: int | None,
        latency_ms: int,
        fallback_answer: str | None = None,
    ) -> AgentRunResult:
        """Assemble the final ``AgentRunResult`` for a finished run.

        Args:
            question: The user question (kept for interface compatibility;
                not read here).
            context: Run context carrying retrieval status, citation ids,
                handoff flag, contact and email notification.
            raw_output: The model's final text; surrounding whitespace is
                stripped.
            ttft_ms: Time to first token in milliseconds, if any was measured.
            latency_ms: Total latency in milliseconds.
            fallback_answer: When truthy, replaces the content entirely and
                suppresses citations and segments.
        """
        content = raw_output.strip()
        segments: list[Any] = []
        citations: list[Any] = []

        if fallback_answer:
            # A fallback is returned verbatim: no citation parsing, no segments.
            content = fallback_answer
        else:
            if context.retrieval_status == "insufficient":
                content = content or insufficiency_fallback()
            elif not content:
                content = system_error_fallback()

            if context.citation_ids or "<doc-ref" in content:
                parsed = parse_citation_segments(content, allowed_document_ids=context.citation_ids)
                content = parsed.content
                segments = parsed.segments
                citations = parsed.citations

            # Guarantee at least one segment whenever there is visible text.
            if content and not segments:
                segments = [ChatTextSegment(text=content)]

        return AgentRunResult(
            content=content,
            segments=segments,
            citations=citations,
            should_handoff=context.should_handoff,
            fallback_answer=fallback_answer,
            ttft_ms=ttft_ms,
            latency_ms=latency_ms,
            contact=context.contact,
            email_notification=context.email_notification,
        )

    async def stream(
        self,
        question: str,
        history: list[ChatHistoryMessage],
        *,
        conversation_id: str | None = None,
        resume_data: dict | None = None,
    ) -> AsyncGenerator[tuple[str, dict[str, Any]], None]:
        """Stream one assistant turn as ``(event_type, payload)`` tuples.

        The input guardrail runs concurrently with the agent. Until the
        guardrail has resolved, status events and tokens are buffered; they
        are released once it passes and discarded if it blocks.

        Event types yielded:

        * ``"status"`` - ``retrieval_start`` / ``retrieval_end`` stages of the
          ``retrieve_brand_context`` tool;
        * ``"ceo_email_sent"`` - emitted at most once after ``hand_off_ceo``
          output when the context holds an email notification;
        * ``"perf"`` - time to first token;
        * ``"token"`` - a visible text delta (or a fallback message);
        * ``"interrupt"`` - the run paused on interruptions; carries the
          serialized state and ends the stream without a ``"done"`` event;
        * ``"done"`` - final content, segments, citations and timings.

        Args:
            question: The user's question.
            history: Prior chat messages.
            conversation_id: Explicit conversation id; falls back to the one
                in *resume_data*, then to a fresh UUID4.
            resume_data: When truthy, must contain ``state_json``; the stored
                run is resumed with pending ``hand_off_ceo`` interruptions
                approved and user contact fields copied into the context.
        """
        resume_conversation_id = resume_data.get("conversation_id") if isinstance(resume_data, dict) else None
        resolved_conversation_id = conversation_id or resume_conversation_id or str(uuid.uuid4())

        state: RunState | None = None
        payload: list[dict[str, str]] = []
        if resume_data:
            state = await self._load_resume_state(resume_data["state_json"])
            context = state._context.context
            context.user_email = resume_data.get("user_email")
            context.user_name = resume_data.get("user_name")
            context.user_phone = resume_data.get("user_phone")
            for item in state.get_interruptions():
                if item.name == "hand_off_ceo":
                    state.approve(item)
        else:
            context = self.build_context(question, history)
            # Input items are only needed for a fresh run; a resumed run
            # continues from its stored state.
            payload = self.build_messages(history, question)

        filter_state = CitationTagStreamFilter()
        tool_calls: dict[str, str] = {}
        token_buffer: list[str] = []
        pending_events: list[tuple[str, dict[str, Any]]] = []
        streamed_visible = False
        ceo_notification_emitted = False
        guardrail_resolved = False
        is_blocked = False
        result = None
        # True once the streamed run has ended or been cancelled explicitly.
        run_settled = False

        started_at = perf_counter()
        ttft_ms: int | None = None

        # Fire guardrail and assistant in parallel
        guardrail_task = asyncio.create_task(run_guardrail(question))

        try:
            if state is not None:
                result = Runner.run_streamed(self.assistant_agent, state)
            else:
                result = Runner.run_streamed(
                    self.assistant_agent,
                    input=payload,
                    context=context,
                    max_turns=self._MAX_TURNS,
                )

            async for event in result.stream_events():
                # Non-blocking poll: has guardrail resolved?
                if not guardrail_resolved and guardrail_task.done():
                    guardrail_resolved = True
                    is_blocked = guardrail_task.result()
                    if is_blocked:
                        result.cancel()
                        break
                    # Flush buffered status events then tokens
                    for evt_type, evt_data in pending_events:
                        yield evt_type, evt_data
                    pending_events.clear()
                    if token_buffer:
                        ttft_ms = self._elapsed_ms(started_at)
                        yield "perf", {"ttft_ms": ttft_ms}
                        for tok in token_buffer:
                            streamed_visible = True
                            yield "token", {"delta": tok}
                        token_buffer.clear()

                # Process event
                if event.type == "run_item_stream_event":
                    if event.name == "tool_called" and isinstance(event.item, ToolCallItem):
                        tool_name = ToolEventInspector.tool_name(event.item)
                        tool_call_id = ToolEventInspector.tool_call_id_from_call(event.item)
                        if tool_name and tool_call_id:
                            # Remembered so the matching output event can be
                            # attributed to a tool by call id.
                            tool_calls[tool_call_id] = tool_name
                        if tool_name == "retrieve_brand_context":
                            evt: tuple[str, dict[str, Any]] = ("status", {"stage": "retrieval_start"})
                            if guardrail_resolved:
                                yield evt[0], evt[1]
                            else:
                                pending_events.append(evt)
                        continue

                    if event.name == "tool_output" and isinstance(event.item, ToolCallOutputItem):
                        tool_call_id = ToolEventInspector.tool_call_id_from_output(event.item)
                        tool_name = tool_calls.get(tool_call_id or "")
                        if tool_name == "retrieve_brand_context":
                            evt = (
                                "status",
                                {
                                    "stage": "retrieval_end",
                                    "sources": [item.model_dump() for item in context.citations],
                                    "handoff": context.should_handoff,
                                    "contact": context.contact.model_dump() if context.contact else None,
                                },
                            )
                            if guardrail_resolved:
                                yield evt[0], evt[1]
                            else:
                                pending_events.append(evt)
                        if tool_name == "hand_off_ceo" and context.email_notification and not ceo_notification_emitted:
                            ceo_notification_emitted = True
                            evt = (
                                "ceo_email_sent",
                                {
                                    "contact": context.contact.model_dump() if context.contact else None,
                                    "email_notification": context.email_notification.model_dump(),
                                },
                            )
                            if guardrail_resolved:
                                yield evt[0], evt[1]
                            else:
                                pending_events.append(evt)
                        continue

                if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
                    visible_delta = filter_state.feed(event.data.delta)
                    if not visible_delta:
                        continue
                    if guardrail_resolved:
                        if ttft_ms is None:
                            ttft_ms = self._elapsed_ms(started_at)
                            yield "perf", {"ttft_ms": ttft_ms}
                        streamed_visible = True
                        yield "token", {"delta": visible_delta}
                    else:
                        token_buffer.append(visible_delta)

            # The event loop above ended (exhausted, or cancelled on a block).
            run_settled = True

            # Assistant finished - if guardrail hasn't resolved yet, await it now
            if not guardrail_resolved:
                is_blocked = await guardrail_task
                guardrail_resolved = True
                if not is_blocked:
                    # Flush buffered events and tokens
                    for evt_type, evt_data in pending_events:
                        yield evt_type, evt_data
                    pending_events.clear()
                    if token_buffer:
                        if ttft_ms is None:
                            ttft_ms = self._elapsed_ms(started_at)
                            yield "perf", {"ttft_ms": ttft_ms}
                        for tok in token_buffer:
                            streamed_visible = True
                            yield "token", {"delta": tok}
                        token_buffer.clear()

            if is_blocked:
                fallback = input_guardrail_fallback()
                latency_ms = self._elapsed_ms(started_at)
                if ttft_ms is None:
                    ttft_ms = latency_ms
                    yield "perf", {"ttft_ms": ttft_ms}
                yield "token", {"delta": fallback}
                run_result = self._build_result(
                    question=question,
                    context=context,
                    raw_output="",
                    ttft_ms=ttft_ms,
                    latency_ms=latency_ms,
                    fallback_answer=fallback,
                )
            else:
                if result and result.interruptions:
                    state_json_dict = result.to_state().to_json()
                    yield "interrupt", {
                        "conversation_id": resolved_conversation_id,
                        "state_json": json.dumps(state_json_dict, default=self._json_default),
                        "interruptions": [
                            {"name": i.name, "arguments": i.arguments} for i in result.interruptions
                        ],
                    }
                    # An interrupted run ends here; no "done" event follows.
                    return

                final_output = str(result.final_output or "") + filter_state.flush()
                latency_ms = self._elapsed_ms(started_at)
                run_result = self._build_result(
                    question=question,
                    context=context,
                    raw_output=final_output,
                    ttft_ms=ttft_ms,
                    latency_ms=latency_ms,
                )

        except Exception:
            logger.exception("Agent streaming failed")
            guardrail_task.cancel()
            if result is not None:
                result.cancel()
            run_settled = True
            fallback = system_error_fallback()
            latency_ms = self._elapsed_ms(started_at)
            # Only surface the fallback as a token when nothing visible has
            # been streamed yet; the "done" payload carries it either way.
            if not streamed_visible:
                if ttft_ms is None:
                    ttft_ms = latency_ms
                    yield "perf", {"ttft_ms": ttft_ms}
                yield "token", {"delta": fallback}
            run_result = self._build_result(
                question=question,
                context=context,
                raw_output="",
                ttft_ms=ttft_ms,
                latency_ms=latency_ms,
                fallback_answer=fallback,
            )
        finally:
            # Also reached when the consumer closes the generator or the
            # surrounding task is cancelled (GeneratorExit / CancelledError are
            # not caught by ``except Exception``). Without this the guardrail
            # task and the streamed run would be left running unattended.
            if not guardrail_task.done():
                guardrail_task.cancel()
            if result is not None and not run_settled:
                result.cancel()

        yield "done", {
            "conversation_id": resolved_conversation_id,
            "content": run_result.content,
            "segments": [segment.model_dump() for segment in run_result.segments],
            "citations": [item.model_dump() for item in run_result.citations],
            "handoff": run_result.should_handoff,
            "contact": run_result.contact.model_dump() if run_result.contact else None,
            "email_notification": run_result.email_notification.model_dump() if run_result.email_notification else None,
            "ttft_ms": run_result.ttft_ms,
            "latency_ms": run_result.latency_ms,
        }