Spaces:
Paused
Paused
File size: 11,051 Bytes
a5784e9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 | import asyncio
from typing import Any, Callable, Dict, List, Optional, Tuple
from playwright.async_api import expect as expect_async
from browser_utils.operations import (
_get_final_response_content,
_wait_for_response_completion,
save_error_snapshot,
)
from config import (
EDIT_MESSAGE_BUTTON_SELECTOR,
PROMPT_TEXTAREA_SELECTOR,
RESPONSE_CONTAINER_SELECTOR,
RESPONSE_TEXT_SELECTOR,
SUBMIT_BUTTON_SELECTOR,
)
from config.settings import FUNCTION_CALLING_DEBUG
from logging_utils import set_request_id
from models import ClientDisconnectedError
from .base import BaseController
class ResponseController(BaseController):
"""Handles retrieval of AI responses."""
async def get_response(
self,
check_client_disconnected: Callable,
prompt_length: int = 0,
timeout: Optional[float] = None,
) -> str:
"""Retrieve response content."""
set_request_id(self.req_id)
self.logger.debug("[Response] Waiting for and retrieving response...")
try:
# Wait for response container
response_container_locator = self.page.locator(
RESPONSE_CONTAINER_SELECTOR
).last
response_element_locator = response_container_locator.locator(
RESPONSE_TEXT_SELECTOR
)
self.logger.debug(
"[Response] Waiting for response element to be attached to DOM..."
)
await expect_async(response_element_locator).to_be_attached(timeout=90000)
await self._check_disconnect(
check_client_disconnected,
"Retrieve Response - Response element attached",
)
# Wait for response completion
submit_button_locator = self.page.locator(SUBMIT_BUTTON_SELECTOR)
edit_button_locator = self.page.locator(EDIT_MESSAGE_BUTTON_SELECTOR)
input_field_locator = self.page.locator(PROMPT_TEXTAREA_SELECTOR)
self.logger.debug("[Response] Waiting for response completion...")
completion_detected = await _wait_for_response_completion(
self.page,
input_field_locator,
submit_button_locator,
edit_button_locator,
self.req_id,
check_client_disconnected,
None,
prompt_length=prompt_length,
timeout=timeout,
)
if not completion_detected:
self.logger.warning(
"Response completion detection failed, attempting to retrieve current content"
)
else:
self.logger.debug("[Response] Response completion detection successful")
# Get final response content
final_content = await _get_final_response_content(
self.page, self.req_id, check_client_disconnected
)
if not final_content or not final_content.strip():
self.logger.warning("Retrieved response content is empty")
await save_error_snapshot(f"empty_response_{self.req_id}")
# Do not raise exception, return empty content to let caller handle
return ""
self.logger.debug(
f"[Response] Successfully retrieved content ({len(final_content)} chars)"
)
return final_content
except Exception as e:
if isinstance(e, asyncio.CancelledError):
self.logger.info("Retrieve response task cancelled")
raise
self.logger.error(f"Error retrieving response: {e}")
if not isinstance(e, ClientDisconnectedError):
await save_error_snapshot(f"get_response_error_{self.req_id}")
raise
async def ensure_generation_stopped(
self, check_client_disconnected: Callable
) -> None:
"""
Ensure generation has stopped.
If submit button is still enabled, click it to stop generation.
Wait until submit button becomes disabled.
"""
submit_button_locator = self.page.locator(SUBMIT_BUTTON_SELECTOR)
# Check client connection status
check_client_disconnected("Ensure generation stopped - pre-check")
await asyncio.sleep(0.5) # Give UI time to update
# Check if button is still enabled, if so click to stop
try:
is_button_enabled = await submit_button_locator.is_enabled(timeout=2000)
if is_button_enabled:
# Button still enabled after stream completion, click to stop
self.logger.debug(
"[Cleanup] Submit button state: ENABLED -> Clicking stop"
)
await submit_button_locator.click(timeout=5000, force=True)
else:
self.logger.debug(
"[Cleanup] Submit button state: DISABLED (no action needed)"
)
except Exception as button_check_err:
if isinstance(button_check_err, asyncio.CancelledError):
raise
self.logger.warning(f"Failed to check button state: {button_check_err}")
# Wait for button to be disabled
try:
await expect_async(submit_button_locator).to_be_disabled(timeout=30000)
except Exception as e:
if isinstance(e, asyncio.CancelledError):
raise
self.logger.warning(f"Timeout or error ensuring generation stopped: {e}")
# Do not raise even on timeout as this is just a cleanup step
async def detect_function_calls(
self,
check_client_disconnected: Callable,
) -> bool:
"""
Detect if the current response contains function calls.
This is a fast check without full parsing. Uses the
FunctionCallResponseParser for DOM-based detection.
Args:
check_client_disconnected: Callback to check client connection.
Returns:
True if function calls are detected, False otherwise.
"""
try:
from api_utils.utils_ext.function_call_response_parser import (
FunctionCallResponseParser,
)
parser = FunctionCallResponseParser(
page=self.page,
logger=self.logger,
req_id=self.req_id,
)
return await parser.detect_function_calls(check_client_disconnected)
except Exception as e:
if FUNCTION_CALLING_DEBUG:
self.logger.debug(
f"[{self.req_id}] Error detecting function calls: {e}"
)
return False
async def parse_function_calls(
self,
check_client_disconnected: Callable,
) -> Tuple[bool, List[Dict[str, Any]], str]:
"""
Parse function calls from the current response.
Args:
check_client_disconnected: Callback to check client connection.
Returns:
Tuple of:
- has_function_calls: Whether any function calls were found
- function_calls: List of function call dicts with 'name' and 'params'
- text_content: Any remaining text content
"""
try:
from api_utils.utils_ext.function_call_response_parser import (
FunctionCallResponseParser,
)
parser = FunctionCallResponseParser(
page=self.page,
logger=self.logger,
req_id=self.req_id,
)
result = await parser.parse_function_calls(check_client_disconnected)
# Convert ParsedFunctionCall to dict format
function_calls: List[Dict[str, Any]] = []
for fc in result.function_calls:
function_calls.append(
{
"name": fc.name,
"params": fc.arguments,
}
)
return result.has_function_calls, function_calls, result.text_content
except Exception as e:
if FUNCTION_CALLING_DEBUG:
self.logger.error(f"[{self.req_id}] Error parsing function calls: {e}")
return False, [], ""
async def get_response_with_function_calls(
self,
check_client_disconnected: Callable,
prompt_length: int = 0,
timeout: Optional[float] = None,
) -> Dict[str, Any]:
"""
Retrieve response content with function call detection.
This method extends get_response by also checking for function calls
in the response and returning them in a structured format.
Args:
check_client_disconnected: Callback to check client connection.
prompt_length: Length of the prompt for timeout calculation.
timeout: Optional explicit timeout.
Returns:
Dict with:
- content: Text content of the response
- has_function_calls: Whether function calls were detected
- function_calls: List of function call dicts
- raw_content: The raw content string
"""
set_request_id(self.req_id)
if FUNCTION_CALLING_DEBUG:
self.logger.debug(
"[Response] Getting response with function call detection..."
)
result: Dict[str, Any] = {
"content": "",
"has_function_calls": False,
"function_calls": [],
"raw_content": "",
}
try:
# Get the raw response content first
raw_content = await self.get_response(
check_client_disconnected,
prompt_length=prompt_length,
timeout=timeout,
)
result["raw_content"] = raw_content
# Check for function calls
has_fc, function_calls, text_content = await self.parse_function_calls(
check_client_disconnected
)
if has_fc:
if FUNCTION_CALLING_DEBUG:
self.logger.info(
f"[{self.req_id}] Detected {len(function_calls)} function call(s) in response"
)
result["has_function_calls"] = True
result["function_calls"] = function_calls
result["content"] = text_content
else:
result["content"] = raw_content
except Exception as e:
if isinstance(e, asyncio.CancelledError):
raise
if FUNCTION_CALLING_DEBUG:
self.logger.error(
f"[{self.req_id}] Error getting response with FC: {e}"
)
if not isinstance(e, ClientDisconnectedError):
await save_error_snapshot(f"get_response_fc_error_{self.req_id}")
raise
return result
|