Spaces:
Running
Running
| import json | |
| import logging | |
| import time | |
| import threading | |
| from collections.abc import Generator | |
| from gradio import ChatMessage | |
| from gradio.components.chatbot import MetadataDict | |
| from openai import OpenAI, APIError | |
| from openai.types.responses import ( | |
| ResponseInputItemParam, | |
| ResponseOutputMessage, | |
| ResponseOutputText, | |
| ) | |
| import config | |
| import prompts | |
| import security | |
| from tools import ToolRegistry | |
| logger = logging.getLogger(__name__) | |
| IN_CHARACTER_ERROR = "Apologies — technical hiccup on my end. Try asking again in a moment." | |
| class _ThoughtAccordion: | |
| """ | |
| Manages a 'Thinking...' accordion in the Gradio chat UI to display reasoning summaries and tool | |
| calls. The accordion is a gr.ChatMessage with a `metadata` attribute, which causes it to render | |
| as a separate bubble attached to the subsequent assistant message. We asynchronously accumulate | |
| reasoning summaries and tool calls/results as the content of this message. | |
| All methods mutate `ui_messages` in place. Caller should yield it after each update. | |
| """ | |
| def __init__(self, ui_messages: list[ChatMessage]): | |
| self._messages = ui_messages | |
| self._msg_index: int | None = None | |
| self._parts: dict[str, str] = {} # ordered dict of 'thoughts' | |
| self._start = time.time() | |
| self._meta: MetadataDict = {"title": "🤔 Thinking...", "status": "pending"} | |
| self.finalized = False | |
| def add_reasoning_delta(self, key: str, delta: str): | |
| """Accumulate streaming reasoning summary text under `key`.""" | |
| self._parts[key] = self._parts.get(key, "") + delta | |
| self._render() | |
| def set_tool_pending(self, item_id: str, name: str | None): | |
| """Show a tool call as in-progress.""" | |
| self._parts[f"t_{item_id}"] = f"🔧 {name or 'tool'}..." | |
| self._render() | |
| def set_tool_result(self, item_id: str, name: str, result: str): | |
| """Update a tool call line with its result.""" | |
| self._parts[f"t_{item_id}"] = f"🔧 {name}: {result}" | |
| self._render() | |
| def finalize(self): | |
| """Keep accordion open; replace ellipsis and spinner with duration.""" | |
| if self._msg_index is not None and not self.finalized: | |
| self.finalized = True | |
| del self._meta["status"] | |
| self._meta["title"] = "🤔 Thinking" | |
| self._meta["duration"] = round(time.time() - self._start, 2) | |
| self._render() | |
| def _render(self): | |
| r"""Turn all `_parts` into \n-separated entries of thought content.""" | |
| assert self._msg_index is None or self._msg_index == 0 | |
| content = "\n".join(self._parts.values()) | |
| msg = ChatMessage(role="assistant", content=content, metadata=self._meta) | |
| if self._msg_index is None: | |
| self._msg_index = len(self._messages) | |
| self._messages.append(msg) | |
| else: | |
| self._messages[self._msg_index] = msg | |
| def _normalize_mixed_history(messages): | |
| """Build a normalized dict of ONLY user and assistant message texts. Drops context | |
| injections (role=developer), tool calls, reasoning summaries, and message metadata.""" | |
| normed = [] | |
| for m in messages: | |
| if isinstance(m, dict) and m.get("role") in ("user", "assistant") and "content" in m: | |
| normed.append({"role": m["role"], "content": m["content"]}) | |
| elif isinstance(m, ResponseOutputMessage) and isinstance(m.content[0], ResponseOutputText): | |
| normed.append({"role": m.role, "content": m.content[0].text}) | |
| return normed | |
| def _summary_notification_daemon( | |
| client: OpenAI, | |
| messages: list[ResponseInputItemParam], | |
| tool_registry: ToolRegistry, | |
| ) -> None: | |
| """ | |
| Summarize the conversation so-far and send a push notification using the 'send_notification' | |
| tool. Intended to run as a daemon thread so it doesn't block user-facing response. | |
| """ | |
| if 'send_notification' not in tool_registry: | |
| logger.warning('cannot send summary notification: send_notification not registered') | |
| return | |
| summary_corpus = _normalize_mixed_history(messages)[-20:] | |
| try: | |
| resp = client.responses.create( | |
| model=config.INFERENCE_MODEL, | |
| instructions=prompts.SUMMARY_NOTIFICATION, | |
| input=summary_corpus, | |
| ) | |
| tool_registry['send_notification']['fn'](message=resp.output_text) | |
| except Exception as e: | |
| logger.error("Background summary notification failed", exc_info=True) | |
| def stream_turn( | |
| client: OpenAI, | |
| api_messages: list[ResponseInputItemParam], | |
| tool_registry: ToolRegistry, | |
| ) -> Generator[tuple[list[ChatMessage], list[ResponseInputItemParam]], None, None]: | |
| """ | |
| Stream a conversation turn, yielding (new_ui_msgs, api_messages) tuples. | |
| Handles tool calls by executing them and re-streaming for the model's next response. | |
| Reasoning summaries and tool usage are shown in a single collapsible `_ThoughtAccordion`. | |
| """ | |
| api_messages = list(api_messages) # shallow copy; caller gets final state via yield | |
| new_ui_msgs: list[ChatMessage] = [] | |
| tools = tool_registry.get_specs() | |
| loop_count = 0 | |
| thinking = _ThoughtAccordion(new_ui_msgs) | |
| while True: | |
| loop_count += 1 | |
| if loop_count > config.MAX_SEQUENTIAL_TOOL_CALLS: | |
| tools = [] | |
| if loop_count > config.MAX_SEQUENTIAL_TOOL_CALLS + 1: | |
| logger.warning("exceeded %s sequential tool calls", config.MAX_SEQUENTIAL_TOOL_CALLS) | |
| break | |
| try: | |
| stream = client.responses.create( | |
| model=config.INFERENCE_MODEL, | |
| input=api_messages, | |
| tools=tools, | |
| stream=True, | |
| ) | |
| except APIError as e: | |
| logger.error("OpenAI call failed: %s: %s", type(e).__name__, e) | |
| new_ui_msgs.append(ChatMessage(role="assistant", content=IN_CHARACTER_ERROR)) | |
| yield new_ui_msgs, api_messages | |
| break | |
| # per-stream state (resets each time we get a new response after tool use) | |
| response_text = "" | |
| response_text_initiated = False | |
| has_tool_calls = False | |
| try: | |
| for event in stream: | |
| if event.type == 'response.reasoning_summary_text.delta': | |
| key = f'r_{loop_count}_{event.output_index}_{event.summary_index}' | |
| thinking.add_reasoning_delta(key, event.delta) | |
| yield new_ui_msgs, api_messages | |
| elif event.type == 'response.function_call_arguments.done': | |
| thinking.set_tool_pending(event.item_id, event.name) | |
| yield new_ui_msgs, api_messages | |
| elif event.type == 'response.output_text.delta': | |
| if not response_text_initiated: | |
| response_text_initiated = True | |
| thinking.finalize() | |
| new_ui_msgs.append(ChatMessage(role="assistant", content="")) | |
| response_text += event.delta | |
| new_ui_msgs[-1].content = response_text | |
| yield new_ui_msgs, api_messages | |
| elif event.type == 'response.completed': | |
| response = event.response | |
| api_messages.extend(response.output) # type: ignore | |
| for item in response.output: | |
| if item.type != "function_call": | |
| continue | |
| if item.name not in tool_registry: | |
| logger.warning('model tried calling unknown tool %s with args: %s', | |
| item.name, item.arguments) | |
| continue | |
| has_tool_calls = True | |
| # Rate-limit notification tool to prevent abuse | |
| if item.name == 'send_notification': | |
| if not security.rate_limiter.check_notification_rate(): | |
| tool_result = "Notification rate limit reached. Try again later." | |
| else: | |
| tool_result = tool_registry[item.name]['fn'](**json.loads(item.arguments)) | |
| else: | |
| tool_result = tool_registry[item.name]['fn'](**json.loads(item.arguments)) | |
| api_messages.append({ | |
| "type": "function_call_output", | |
| "call_id": item.call_id, | |
| "output": json.dumps(tool_result), | |
| }) | |
| thinking.set_tool_result(item.id, item.name, tool_result) # type: ignore | |
| yield new_ui_msgs, api_messages | |
| except APIError as e: | |
| logger.error("OpenAI stream error: %s: %s", type(e).__name__, e) | |
| new_ui_msgs.append(ChatMessage(role="assistant", content=IN_CHARACTER_ERROR)) | |
| yield new_ui_msgs, api_messages | |
| break | |
| if not has_tool_calls: | |
| # --- OUTPUT SECURITY GATE: Filter completed response for disclosure --- | |
| if response_text: | |
| filtered = security.filter_output(response_text) | |
| if filtered != response_text: | |
| new_ui_msgs[-1].content = filtered | |
| yield new_ui_msgs, api_messages | |
| break | |
| # else: tool calls were answered, so we loop to stream another response | |
| # cleanup after `break` | |
| if not thinking.finalized: | |
| thinking.finalize() | |
| yield new_ui_msgs, api_messages | |
| # every other user message, update Lamonte with a conversation summary (background) | |
| user_m_count = len([m for m in api_messages if isinstance(m, dict) and m.get('role') == 'user']) | |
| if user_m_count % 2 == 0: | |
| threading.Thread( | |
| target=_summary_notification_daemon, | |
| args=(client, api_messages, tool_registry), | |
| daemon=True, | |
| ).start() | |