from __future__ import annotations import asyncio import logging from collections.abc import AsyncGenerator from time import perf_counter from typing import Any import json import uuid from agents import Agent, ModelSettings, Runner, RunState from agents.items import ToolCallItem, ToolCallOutputItem from openai.types.responses import ResponseTextDeltaEvent from src.agents.state import AgentContext, AgentRunResult from src.agents.flow import run_guardrail from src.utils.message_builder import MessageBuilder from src.agents.prompts import get_prompt_bundle from src.utils.tool_event_inspector import ToolEventInspector from src.agents.tools import hand_off_ceo, retrieve_brand_context from src.utils.agent_utils import ( insufficiency_fallback, input_guardrail_fallback, system_error_fallback, ) from src.schemas import ChatHistoryMessage, ChatTextSegment from src.services.citations import CitationTagStreamFilter, parse_citation_segments from src.services.llm import get_chat_model logger = logging.getLogger(__name__) class AgentService: def __init__(self) -> None: self._assistant_agent: Agent[AgentContext] | None = None @property def assistant_agent(self) -> Agent[AgentContext]: if self._assistant_agent is None: bundle = get_prompt_bundle() self._assistant_agent = Agent( name="brand-assistant", instructions=bundle.system_prompt, model=get_chat_model(), model_settings=ModelSettings( parallel_tool_calls=False, ), tools=[retrieve_brand_context, hand_off_ceo], ) return self._assistant_agent def build_context(self, question: str, history: list[ChatHistoryMessage]) -> AgentContext: return AgentContext( question=question, message_count=len(history) + 1, prompt_bundle=get_prompt_bundle(), ) def build_messages(self, history: list[ChatHistoryMessage], question: str) -> list[dict[str, str]]: return MessageBuilder.build_input_items(history, question) @staticmethod def _normalize_resume_state_payload(node: Any) -> None: """Normalize persisted message parts for SDK resume compatibility.""" if isinstance(node, dict): # Some persisted chat-completions items include a message envelope like: # {"role":"assistant","content":[{"type":"text","text":"..."}],"status":"completed"} # Normalize it into the simpler structure expected by the converter. role = node.get("role") content = node.get("content") if role in {"assistant", "user", "system"} and isinstance(content, list): flattened_parts: list[str] = [] for part in content: if isinstance(part, dict) and part.get("type") in {"text", "output_text"}: text_value = part.get("text") if isinstance(text_value, str): flattened_parts.append(text_value) if flattened_parts: node["content"] = "".join(flattened_parts) node.pop("status", None) if node.get("type") == "output_text" and "text" in node: node["type"] = "text" for value in node.values(): AgentService._normalize_resume_state_payload(value) return if isinstance(node, list): for value in node: AgentService._normalize_resume_state_payload(value) async def _load_resume_state(self, state_json: str) -> RunState: state_dict = json.loads(state_json) self._normalize_resume_state_payload(state_dict) state = await RunState.from_json( self.assistant_agent, state_dict, context_deserializer=lambda x: AgentContext(**x), ) if isinstance(state._context.context.citation_ids, list): state._context.context.citation_ids = set(state._context.context.citation_ids) return state def _build_result( self, *, question: str, context: AgentContext, raw_output: str, ttft_ms: int | None, latency_ms: int, fallback_answer: str | None = None, ) -> AgentRunResult: content = raw_output.strip() segments = [] citations = [] if fallback_answer: content = fallback_answer elif context.retrieval_status == "insufficient": content = content or insufficiency_fallback() elif not content: content = system_error_fallback() if not fallback_answer and (context.citation_ids or " AsyncGenerator[tuple[str, dict[str, Any]], None]: resume_conversation_id = resume_data.get("conversation_id") if isinstance(resume_data, dict) else None resolved_conversation_id = conversation_id or resume_conversation_id or str(uuid.uuid4()) if resume_data: state = await self._load_resume_state(resume_data["state_json"]) context = state._context.context context.user_email = resume_data.get("user_email") context.user_name = resume_data.get("user_name") context.user_phone = resume_data.get("user_phone") for item in state.get_interruptions(): if item.name == "hand_off_ceo": state.approve(item) payload = self.build_messages(history, question) else: context = self.build_context(question, history) payload = self.build_messages(history, question) started_at = perf_counter() ttft_ms: int | None = None # Fire guardrail and assistant in parallel guardrail_task = asyncio.create_task( run_guardrail(question) ) filter_state = CitationTagStreamFilter() tool_calls: dict[str, str] = {} token_buffer: list[str] = [] pending_events: list[tuple[str, dict[str, Any]]] = [] final_output = "" streamed_visible = False ceo_notification_emitted = False guardrail_resolved = False is_blocked = False result = None try: if resume_data: result = Runner.run_streamed(self.assistant_agent, state) else: result = Runner.run_streamed( self.assistant_agent, input=payload, context=context, max_turns=6, ) async for event in result.stream_events(): # Non-blocking poll: has guardrail resolved? if not guardrail_resolved and guardrail_task.done(): guardrail_resolved = True is_blocked = guardrail_task.result() if is_blocked: result.cancel() break # Flush buffered status events then tokens for evt_type, evt_data in pending_events: yield evt_type, evt_data pending_events.clear() if token_buffer: ttft_ms = max(1, round((perf_counter() - started_at) * 1000)) yield "perf", {"ttft_ms": ttft_ms} for tok in token_buffer: streamed_visible = True yield "token", {"delta": tok} token_buffer.clear() # Process event if event.type == "run_item_stream_event": if event.name == "tool_called" and isinstance(event.item, ToolCallItem): tool_name = ToolEventInspector.tool_name(event.item) tool_call_id = ToolEventInspector.tool_call_id_from_call(event.item) if tool_name and tool_call_id: tool_calls[tool_call_id] = tool_name if tool_name == "retrieve_brand_context": evt: tuple[str, dict[str, Any]] = ("status", {"stage": "retrieval_start"}) if guardrail_resolved: yield evt[0], evt[1] else: pending_events.append(evt) continue if event.name == "tool_output" and isinstance(event.item, ToolCallOutputItem): tool_call_id = ToolEventInspector.tool_call_id_from_output(event.item) tool_name = tool_calls.get(tool_call_id or "") if tool_name == "retrieve_brand_context": evt = ( "status", { "stage": "retrieval_end", "sources": [item.model_dump() for item in context.citations], "handoff": context.should_handoff, "contact": context.contact.model_dump() if context.contact else None, }, ) if guardrail_resolved: yield evt[0], evt[1] else: pending_events.append(evt) if tool_name == "hand_off_ceo" and context.email_notification and not ceo_notification_emitted: ceo_notification_emitted = True evt = ( "ceo_email_sent", { "contact": context.contact.model_dump() if context.contact else None, "email_notification": context.email_notification.model_dump(), }, ) if guardrail_resolved: yield evt[0], evt[1] else: pending_events.append(evt) continue if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent): visible_delta = filter_state.feed(event.data.delta) if not visible_delta: continue if guardrail_resolved: if ttft_ms is None: ttft_ms = max(1, round((perf_counter() - started_at) * 1000)) yield "perf", {"ttft_ms": ttft_ms} streamed_visible = True yield "token", {"delta": visible_delta} else: token_buffer.append(visible_delta) # Assistant finished — if guardrail hasn't resolved yet, await it now if not guardrail_resolved: is_blocked = await guardrail_task guardrail_resolved = True if not is_blocked: # Flush buffered events and tokens for evt_type, evt_data in pending_events: yield evt_type, evt_data if token_buffer: if ttft_ms is None: ttft_ms = max(1, round((perf_counter() - started_at) * 1000)) yield "perf", {"ttft_ms": ttft_ms} for tok in token_buffer: streamed_visible = True yield "token", {"delta": tok} if is_blocked: fallback = input_guardrail_fallback() latency_ms = max(1, round((perf_counter() - started_at) * 1000)) if ttft_ms is None: ttft_ms = latency_ms yield "perf", {"ttft_ms": ttft_ms} yield "token", {"delta": fallback} run_result = self._build_result( question=question, context=context, raw_output="", ttft_ms=ttft_ms, latency_ms=latency_ms, fallback_answer=fallback, ) else: if result and result.interruptions: state_json_dict = result.to_state().to_json() yield "interrupt", { "conversation_id": resolved_conversation_id, "state_json": json.dumps(state_json_dict, default=lambda x: list(x) if isinstance(x, set) else x), "interruptions": [ {"name": i.name, "arguments": i.arguments} for i in result.interruptions ] } return final_output = str(result.final_output or "") + filter_state.flush() latency_ms = max(1, round((perf_counter() - started_at) * 1000)) run_result = self._build_result( question=question, context=context, raw_output=final_output, ttft_ms=ttft_ms, latency_ms=latency_ms, ) except Exception: logger.exception("Agent streaming failed") guardrail_task.cancel() if result is not None: result.cancel() fallback = system_error_fallback() latency_ms = max(1, round((perf_counter() - started_at) * 1000)) if not streamed_visible: if ttft_ms is None: ttft_ms = latency_ms yield "perf", {"ttft_ms": ttft_ms} yield "token", {"delta": fallback} run_result = self._build_result( question=question, context=context, raw_output="", ttft_ms=ttft_ms, latency_ms=latency_ms, fallback_answer=fallback, ) yield "done", { "conversation_id": resolved_conversation_id, "content": run_result.content, "segments": [segment.model_dump() for segment in run_result.segments], "citations": [item.model_dump() for item in run_result.citations], "handoff": run_result.should_handoff, "contact": run_result.contact.model_dump() if run_result.contact else None, "email_notification": run_result.email_notification.model_dump() if run_result.email_notification else None, "ttft_ms": run_result.ttft_ms, "latency_ms": run_result.latency_ms, }