Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import asyncio | |
| import logging | |
| from collections.abc import AsyncGenerator | |
| from time import perf_counter | |
| from typing import Any | |
| import json | |
| import uuid | |
| from agents import Agent, ModelSettings, Runner, RunState | |
| from agents.items import ToolCallItem, ToolCallOutputItem | |
| from openai.types.responses import ResponseTextDeltaEvent | |
| from src.agents.state import AgentContext, AgentRunResult | |
| from src.agents.flow import run_guardrail | |
| from src.utils.message_builder import MessageBuilder | |
| from src.agents.prompts import get_prompt_bundle | |
| from src.utils.tool_event_inspector import ToolEventInspector | |
| from src.agents.tools import hand_off_ceo, retrieve_brand_context | |
| from src.utils.agent_utils import ( | |
| insufficiency_fallback, | |
| input_guardrail_fallback, | |
| system_error_fallback, | |
| ) | |
| from src.schemas import ChatHistoryMessage, ChatTextSegment | |
| from src.services.citations import CitationTagStreamFilter, parse_citation_segments | |
| from src.services.llm import get_chat_model | |
| logger = logging.getLogger(__name__) | |
class AgentService:
    """Build the ``brand-assistant`` agent and stream its answers as events.

    The service lazily creates a single ``Agent`` and exposes ``stream``, an
    async generator that runs the agent while an input guardrail is evaluated
    concurrently. Anything the agent produces before the guardrail verdict is
    known is buffered and only released if the guardrail does not block.
    """

    def __init__(self) -> None:
        """Start without an agent; it is created on first ``assistant_agent`` access."""
        self._assistant_agent: Agent[AgentContext] | None = None

    @property
    def assistant_agent(self) -> Agent[AgentContext]:
        """Return the cached assistant agent, building it on first access.

        This must be a property: the rest of the class reads
        ``self.assistant_agent`` as an attribute and hands it straight to
        ``RunState.from_json`` and ``Runner.run_streamed``.
        """
        if self._assistant_agent is None:
            bundle = get_prompt_bundle()
            self._assistant_agent = Agent(
                name="brand-assistant",
                instructions=bundle.system_prompt,
                model=get_chat_model(),
                model_settings=ModelSettings(
                    parallel_tool_calls=False,
                ),
                tools=[retrieve_brand_context, hand_off_ceo],
            )
        return self._assistant_agent

    def build_context(self, question: str, history: list[ChatHistoryMessage]) -> AgentContext:
        """Create the per-run context for ``question``.

        ``message_count`` is the history length plus one for the new question.
        """
        return AgentContext(
            question=question,
            message_count=len(history) + 1,
            prompt_bundle=get_prompt_bundle(),
        )

    def build_messages(self, history: list[ChatHistoryMessage], question: str) -> list[dict[str, str]]:
        """Convert the history plus the new question into agent input items."""
        return MessageBuilder.build_input_items(history, question)

    @staticmethod
    def _normalize_resume_state_payload(node: Any) -> None:
        """Normalize persisted message parts for SDK resume compatibility.

        Walks ``node`` (any nesting of dicts and lists) and mutates it in place:

        * a dict with ``role`` in assistant/user/system whose ``content`` is a
          list has the string ``text`` of its ``text``/``output_text`` parts
          concatenated into a plain string, and its ``status`` key removed
          (only when at least one such text part was found);
        * a dict with ``type == "output_text"`` and a ``text`` key is retyped
          to ``"text"``.

        Values that are neither dict nor list are left untouched.
        """
        if isinstance(node, list):
            for child in node:
                AgentService._normalize_resume_state_payload(child)
            return
        if not isinstance(node, dict):
            return

        # Some persisted chat-completions items include a message envelope like:
        # {"role":"assistant","content":[{"type":"text","text":"..."}],"status":"completed"}
        # Normalize it into the simpler structure expected by the converter.
        content = node.get("content")
        if node.get("role") in {"assistant", "user", "system"} and isinstance(content, list):
            texts = [
                part["text"]
                for part in content
                if isinstance(part, dict)
                and part.get("type") in {"text", "output_text"}
                and isinstance(part.get("text"), str)
            ]
            if texts:
                node["content"] = "".join(texts)
                node.pop("status", None)

        if node.get("type") == "output_text" and "text" in node:
            node["type"] = "text"

        for child in node.values():
            AgentService._normalize_resume_state_payload(child)

    @staticmethod
    def _json_default(value: Any) -> Any:
        """``json.dumps`` fallback: serialize sets as lists, reject anything else.

        Raising ``TypeError`` for unsupported values follows the ``default``
        contract of the ``json`` module; returning the value unchanged would
        instead surface as a misleading "Circular reference detected" error.
        """
        if isinstance(value, set):
            return list(value)
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    @staticmethod
    def _elapsed_ms(started_at: float) -> int:
        """Return whole milliseconds since ``started_at`` (a ``perf_counter`` value), at least 1."""
        return max(1, round((perf_counter() - started_at) * 1000))

    async def _load_resume_state(self, state_json: str) -> RunState:
        """Rebuild a ``RunState`` from its serialized JSON form.

        The decoded payload is normalized first, then handed to
        ``RunState.from_json`` together with the assistant agent. If the
        restored context carries ``citation_ids`` as a list it is converted
        back to a set.
        """
        state_dict = json.loads(state_json)
        self._normalize_resume_state_payload(state_dict)
        state = await RunState.from_json(
            self.assistant_agent,
            state_dict,
            context_deserializer=lambda raw: AgentContext(**raw),
        )
        restored_context = state._context.context
        if isinstance(restored_context.citation_ids, list):
            restored_context.citation_ids = set(restored_context.citation_ids)
        return state

    def _build_result(
        self,
        *,
        question: str,
        context: AgentContext,
        raw_output: str,
        ttft_ms: int | None,
        latency_ms: int,
        fallback_answer: str | None = None,
    ) -> AgentRunResult:
        """Assemble the final ``AgentRunResult`` for a finished run.

        Args:
            question: The user question (accepted for interface compatibility;
                not read here).
            context: Run context supplying retrieval status, citation ids,
                handoff flag, contact and email notification.
            raw_output: Text produced by the agent; stripped before use.
            ttft_ms: Time to first token in milliseconds, if any was measured.
            latency_ms: Total run latency in milliseconds.
            fallback_answer: When truthy, replaces the content entirely and
                suppresses segments and citations.

        Returns:
            The populated ``AgentRunResult``.
        """
        content = raw_output.strip()
        segments: list[Any] = []
        citations: list[Any] = []

        if fallback_answer:
            # A fallback is shown verbatim: no citation parsing, no segments.
            content = fallback_answer
        else:
            if context.retrieval_status == "insufficient":
                content = content or insufficiency_fallback()
            elif not content:
                content = system_error_fallback()

            if context.citation_ids or "<doc-ref" in content:
                parsed = parse_citation_segments(content, allowed_document_ids=context.citation_ids)
                content = parsed.content
                segments = parsed.segments
                citations = parsed.citations

            if content and not segments:
                segments = [ChatTextSegment(text=content)]

        return AgentRunResult(
            content=content,
            segments=segments,
            citations=citations,
            should_handoff=context.should_handoff,
            fallback_answer=fallback_answer,
            ttft_ms=ttft_ms,
            latency_ms=latency_ms,
            contact=context.contact,
            email_notification=context.email_notification,
        )

    async def stream(
        self,
        question: str,
        history: list[ChatHistoryMessage],
        *,
        conversation_id: str | None = None,
        resume_data: dict | None = None,
    ) -> AsyncGenerator[tuple[str, dict[str, Any]], None]:
        """Run the assistant for ``question`` and yield ``(event_type, payload)`` pairs.

        Event types yielded: ``"status"`` (retrieval start/end),
        ``"ceo_email_sent"``, ``"perf"`` (time to first token), ``"token"``
        (visible text delta), ``"interrupt"`` and ``"done"``.

        The guardrail runs concurrently with the agent. Until its verdict is
        known, status events and tokens are buffered; they are flushed if the
        guardrail passes and discarded if it blocks, in which case the input
        guardrail fallback text is emitted instead.

        Args:
            question: The user's question.
            history: Prior chat messages.
            conversation_id: Explicit conversation id; falls back to the one in
                ``resume_data`` and then to a fresh UUID.
            resume_data: When truthy, the run is resumed from
                ``resume_data["state_json"]``; pending ``hand_off_ceo``
                interruptions are approved and the user's email, name and
                phone are copied onto the restored context.

        Yields:
            Event tuples as described above. When the run ends with pending
            interruptions, a single ``"interrupt"`` event carrying the
            serialized state is yielded and the generator returns without a
            ``"done"`` event. Otherwise the last event is ``"done"``.
        """
        resume_conversation_id = (
            resume_data.get("conversation_id") if isinstance(resume_data, dict) else None
        )
        resolved_conversation_id = conversation_id or resume_conversation_id or str(uuid.uuid4())

        state: RunState | None = None
        payload: list[dict[str, str]] = []
        if resume_data:
            state = await self._load_resume_state(resume_data["state_json"])
            context = state._context.context
            context.user_email = resume_data.get("user_email")
            context.user_name = resume_data.get("user_name")
            context.user_phone = resume_data.get("user_phone")
            for item in state.get_interruptions():
                if item.name == "hand_off_ceo":
                    state.approve(item)
        else:
            context = self.build_context(question, history)
            payload = self.build_messages(history, question)

        started_at = perf_counter()
        ttft_ms: int | None = None

        # Fire guardrail and assistant in parallel
        guardrail_task = asyncio.create_task(run_guardrail(question))

        filter_state = CitationTagStreamFilter()
        tool_calls: dict[str, str] = {}
        token_buffer: list[str] = []
        pending_events: list[tuple[str, dict[str, Any]]] = []
        final_output = ""
        streamed_visible = False
        ceo_notification_emitted = False
        guardrail_resolved = False
        is_blocked = False
        result = None

        try:
            if resume_data:
                result = Runner.run_streamed(self.assistant_agent, state)
            else:
                result = Runner.run_streamed(
                    self.assistant_agent,
                    input=payload,
                    context=context,
                    max_turns=6,
                )

            async for event in result.stream_events():
                # Non-blocking poll: has guardrail resolved?
                if not guardrail_resolved and guardrail_task.done():
                    guardrail_resolved = True
                    is_blocked = guardrail_task.result()
                    if is_blocked:
                        result.cancel()
                        break
                    # Flush buffered status events then tokens
                    for buffered_event in pending_events:
                        yield buffered_event
                    pending_events.clear()
                    if token_buffer:
                        if ttft_ms is None:
                            ttft_ms = self._elapsed_ms(started_at)
                            yield "perf", {"ttft_ms": ttft_ms}
                        for tok in token_buffer:
                            streamed_visible = True
                            yield "token", {"delta": tok}
                        token_buffer.clear()

                # Process event
                if event.type == "run_item_stream_event":
                    if event.name == "tool_called" and isinstance(event.item, ToolCallItem):
                        tool_name = ToolEventInspector.tool_name(event.item)
                        tool_call_id = ToolEventInspector.tool_call_id_from_call(event.item)
                        if tool_name and tool_call_id:
                            # Remembered so the matching output event can be attributed.
                            tool_calls[tool_call_id] = tool_name
                        if tool_name == "retrieve_brand_context":
                            evt: tuple[str, dict[str, Any]] = ("status", {"stage": "retrieval_start"})
                            if guardrail_resolved:
                                yield evt
                            else:
                                pending_events.append(evt)
                        continue

                    if event.name == "tool_output" and isinstance(event.item, ToolCallOutputItem):
                        tool_call_id = ToolEventInspector.tool_call_id_from_output(event.item)
                        tool_name = tool_calls.get(tool_call_id or "")
                        if tool_name == "retrieve_brand_context":
                            evt = (
                                "status",
                                {
                                    "stage": "retrieval_end",
                                    "sources": [item.model_dump() for item in context.citations],
                                    "handoff": context.should_handoff,
                                    "contact": context.contact.model_dump() if context.contact else None,
                                },
                            )
                            if guardrail_resolved:
                                yield evt
                            else:
                                pending_events.append(evt)
                        if (
                            tool_name == "hand_off_ceo"
                            and context.email_notification
                            and not ceo_notification_emitted
                        ):
                            # Emitted at most once per stream.
                            ceo_notification_emitted = True
                            evt = (
                                "ceo_email_sent",
                                {
                                    "contact": context.contact.model_dump() if context.contact else None,
                                    "email_notification": context.email_notification.model_dump(),
                                },
                            )
                            if guardrail_resolved:
                                yield evt
                            else:
                                pending_events.append(evt)
                        continue

                if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
                    visible_delta = filter_state.feed(event.data.delta)
                    if not visible_delta:
                        continue
                    if guardrail_resolved:
                        if ttft_ms is None:
                            ttft_ms = self._elapsed_ms(started_at)
                            yield "perf", {"ttft_ms": ttft_ms}
                        streamed_visible = True
                        yield "token", {"delta": visible_delta}
                    else:
                        token_buffer.append(visible_delta)

            # Assistant finished — if guardrail hasn't resolved yet, await it now
            if not guardrail_resolved:
                is_blocked = await guardrail_task
                guardrail_resolved = True
                if not is_blocked:
                    # Flush buffered events and tokens
                    for buffered_event in pending_events:
                        yield buffered_event
                    if token_buffer:
                        if ttft_ms is None:
                            ttft_ms = self._elapsed_ms(started_at)
                            yield "perf", {"ttft_ms": ttft_ms}
                        for tok in token_buffer:
                            streamed_visible = True
                            yield "token", {"delta": tok}

            if is_blocked:
                fallback = input_guardrail_fallback()
                latency_ms = self._elapsed_ms(started_at)
                if ttft_ms is None:
                    ttft_ms = latency_ms
                    yield "perf", {"ttft_ms": ttft_ms}
                yield "token", {"delta": fallback}
                run_result = self._build_result(
                    question=question,
                    context=context,
                    raw_output="",
                    ttft_ms=ttft_ms,
                    latency_ms=latency_ms,
                    fallback_answer=fallback,
                )
            else:
                if result and result.interruptions:
                    # The run paused for approval: hand the serialized state back
                    # to the caller and stop without a "done" event.
                    state_json_dict = result.to_state().to_json()
                    yield "interrupt", {
                        "conversation_id": resolved_conversation_id,
                        "state_json": json.dumps(state_json_dict, default=self._json_default),
                        "interruptions": [
                            {"name": i.name, "arguments": i.arguments} for i in result.interruptions
                        ],
                    }
                    return

                final_output = str(result.final_output or "") + filter_state.flush()
                latency_ms = self._elapsed_ms(started_at)
                run_result = self._build_result(
                    question=question,
                    context=context,
                    raw_output=final_output,
                    ttft_ms=ttft_ms,
                    latency_ms=latency_ms,
                )
        except Exception:
            logger.exception("Agent streaming failed")
            guardrail_task.cancel()
            if result is not None:
                result.cancel()
            fallback = system_error_fallback()
            latency_ms = self._elapsed_ms(started_at)
            if not streamed_visible:
                # Nothing reached the client yet, so show the fallback text.
                if ttft_ms is None:
                    ttft_ms = latency_ms
                    yield "perf", {"ttft_ms": ttft_ms}
                yield "token", {"delta": fallback}
            run_result = self._build_result(
                question=question,
                context=context,
                raw_output="",
                ttft_ms=ttft_ms,
                latency_ms=latency_ms,
                fallback_answer=fallback,
            )
        finally:
            # ``except Exception`` does not see GeneratorExit or CancelledError,
            # so make sure an abandoned stream does not leave the guardrail
            # task running in the background.
            if not guardrail_task.done():
                guardrail_task.cancel()

        yield "done", {
            "conversation_id": resolved_conversation_id,
            "content": run_result.content,
            "segments": [segment.model_dump() for segment in run_result.segments],
            "citations": [item.model_dump() for item in run_result.citations],
            "handoff": run_result.should_handoff,
            "contact": run_result.contact.model_dump() if run_result.contact else None,
            "email_notification": run_result.email_notification.model_dump() if run_result.email_notification else None,
            "ttft_ms": run_result.ttft_ms,
            "latency_ms": run_result.latency_ms,
        }